diff --git a/.github/workflows/build-ce7-releases.yml b/.github/workflows/build-ce7-releases.yml index cf8fda8f8..fb51a3e1e 100644 --- a/.github/workflows/build-ce7-releases.yml +++ b/.github/workflows/build-ce7-releases.yml @@ -12,7 +12,7 @@ jobs: strategy: matrix: sparkver: [spark303, spark333] - blazever: [2.0.9.1] + blazever: [3.0.0] steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/tpcds.yml b/.github/workflows/tpcds.yml index 98722c4ff..2b280d9d7 100644 --- a/.github/workflows/tpcds.yml +++ b/.github/workflows/tpcds.yml @@ -34,19 +34,18 @@ jobs: with: {version: "21.7"} - uses: actions-rust-lang/setup-rust-toolchain@v1 - with: {rustflags: --allow warnings -C target-cpu=native} + with: + toolchain: nightly + rustflags: --allow warnings -C target-feature=+aes + components: + cargo + rustfmt - name: Rustfmt Check uses: actions-rust-lang/rustfmt@v1 - ## - name: Rust Clippy Check - ## uses: actions-rs/clippy-check@v1 - ## with: - ## token: ${{ secrets.GITHUB_TOKEN }} - ## args: --all-features - - name: Cargo test - run: cargo test --workspace --all-features + run: cargo +nightly test --workspace --all-features - name: Build Spark303 run: mvn package -Ppre -Pspark303 diff --git a/Cargo.lock b/Cargo.lock index c5f0f8587..89905762b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -97,7 +97,7 @@ checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" [[package]] name = "arrow" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "arrow-arith", "arrow-array", @@ -117,7 +117,7 @@ dependencies = [ [[package]] name = "arrow-arith" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "arrow-array", "arrow-buffer", @@ -131,7 +131,7 @@ dependencies = [ [[package]] name = "arrow-array" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "ahash", "arrow-buffer", @@ -147,7 +147,7 @@ dependencies = [ [[package]] name = "arrow-buffer" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "bytes", "half", @@ -157,7 +157,7 @@ dependencies = [ [[package]] name = "arrow-cast" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "arrow-array", "arrow-buffer", 
@@ -175,7 +175,7 @@ dependencies = [ [[package]] name = "arrow-csv" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "arrow-array", "arrow-buffer", @@ -193,7 +193,7 @@ dependencies = [ [[package]] name = "arrow-data" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "arrow-buffer", "arrow-schema", @@ -204,7 +204,7 @@ dependencies = [ [[package]] name = "arrow-ipc" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "arrow-array", "arrow-buffer", @@ -218,7 +218,7 @@ dependencies = [ [[package]] name = "arrow-json" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "arrow-array", "arrow-buffer", @@ -237,7 +237,7 @@ dependencies = [ [[package]] name = "arrow-ord" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "arrow-array", "arrow-buffer", @@ -251,7 +251,7 @@ dependencies = [ [[package]] name = "arrow-row" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "ahash", "arrow-array", @@ -265,7 +265,7 @@ dependencies = [ [[package]] name = "arrow-schema" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "bitflags 2.5.0", "serde", @@ -274,7 +274,7 @@ dependencies = [ [[package]] name = "arrow-select" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "ahash", "arrow-array", @@ -287,7 +287,7 @@ dependencies = 
[ [[package]] name = "arrow-string" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "arrow-array", "arrow-buffer", @@ -751,7 +751,7 @@ dependencies = [ [[package]] name = "datafusion" version = "36.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=71433f743b2c399ea1728531b0e56fd7c6ef5282#71433f743b2c399ea1728531b0e56fd7c6ef5282" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=17b1ad3c7432391b94dd54e48a60db6d5712a7ef#17b1ad3c7432391b94dd54e48a60db6d5712a7ef" dependencies = [ "ahash", "arrow", @@ -800,7 +800,7 @@ dependencies = [ [[package]] name = "datafusion-common" version = "36.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=71433f743b2c399ea1728531b0e56fd7c6ef5282#71433f743b2c399ea1728531b0e56fd7c6ef5282" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=17b1ad3c7432391b94dd54e48a60db6d5712a7ef#17b1ad3c7432391b94dd54e48a60db6d5712a7ef" dependencies = [ "ahash", "arrow", @@ -819,7 +819,7 @@ dependencies = [ [[package]] name = "datafusion-execution" version = "36.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=71433f743b2c399ea1728531b0e56fd7c6ef5282#71433f743b2c399ea1728531b0e56fd7c6ef5282" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=17b1ad3c7432391b94dd54e48a60db6d5712a7ef#17b1ad3c7432391b94dd54e48a60db6d5712a7ef" dependencies = [ "arrow", "chrono", @@ -839,7 +839,7 @@ dependencies = [ [[package]] name = "datafusion-expr" version = "36.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=71433f743b2c399ea1728531b0e56fd7c6ef5282#71433f743b2c399ea1728531b0e56fd7c6ef5282" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=17b1ad3c7432391b94dd54e48a60db6d5712a7ef#17b1ad3c7432391b94dd54e48a60db6d5712a7ef" dependencies = [ "ahash", "arrow", @@ -923,6 +923,7 @@ dependencies = [ "arrow", "async-trait", "base64 0.22.1", + "bitvec", "blaze-jni-bridge", "byteorder", "bytes", @@ -957,7 +958,7 @@ dependencies = [ [[package]] name = "datafusion-functions" version = "36.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=71433f743b2c399ea1728531b0e56fd7c6ef5282#71433f743b2c399ea1728531b0e56fd7c6ef5282" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=17b1ad3c7432391b94dd54e48a60db6d5712a7ef#17b1ad3c7432391b94dd54e48a60db6d5712a7ef" dependencies = [ "arrow", "base64 0.21.7", @@ -971,7 +972,7 @@ dependencies = [ [[package]] name = "datafusion-functions-array" version = "36.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=71433f743b2c399ea1728531b0e56fd7c6ef5282#71433f743b2c399ea1728531b0e56fd7c6ef5282" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=17b1ad3c7432391b94dd54e48a60db6d5712a7ef#17b1ad3c7432391b94dd54e48a60db6d5712a7ef" dependencies = [ "arrow", "datafusion-common", @@ -984,7 +985,7 @@ dependencies = [ [[package]] name = "datafusion-optimizer" version = "36.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=71433f743b2c399ea1728531b0e56fd7c6ef5282#71433f743b2c399ea1728531b0e56fd7c6ef5282" +source = 
"git+https://github.com/blaze-init/arrow-datafusion.git?rev=17b1ad3c7432391b94dd54e48a60db6d5712a7ef#17b1ad3c7432391b94dd54e48a60db6d5712a7ef" dependencies = [ "arrow", "async-trait", @@ -1001,7 +1002,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" version = "36.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=71433f743b2c399ea1728531b0e56fd7c6ef5282#71433f743b2c399ea1728531b0e56fd7c6ef5282" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=17b1ad3c7432391b94dd54e48a60db6d5712a7ef#17b1ad3c7432391b94dd54e48a60db6d5712a7ef" dependencies = [ "ahash", "arrow", @@ -1036,7 +1037,7 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" version = "36.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=71433f743b2c399ea1728531b0e56fd7c6ef5282#71433f743b2c399ea1728531b0e56fd7c6ef5282" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=17b1ad3c7432391b94dd54e48a60db6d5712a7ef#17b1ad3c7432391b94dd54e48a60db6d5712a7ef" dependencies = [ "ahash", "arrow", @@ -1066,7 +1067,7 @@ dependencies = [ [[package]] name = "datafusion-sql" version = "36.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=71433f743b2c399ea1728531b0e56fd7c6ef5282#71433f743b2c399ea1728531b0e56fd7c6ef5282" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=17b1ad3c7432391b94dd54e48a60db6d5712a7ef#17b1ad3c7432391b94dd54e48a60db6d5712a7ef" dependencies = [ "arrow", "arrow-schema", @@ -1866,7 +1867,7 @@ dependencies = [ [[package]] name = "parquet" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "ahash", "arrow-array", diff --git a/Cargo.toml b/Cargo.toml index 5052eab71..ad86c08ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,26 +64,26 @@ serde_json = { version = "1.0.96" } [patch.crates-io] # datafusion: branch=v36-blaze -datafusion = { git = "https://github.com/blaze-init/arrow-datafusion.git", rev = "71433f743b2c399ea1728531b0e56fd7c6ef5282"} -datafusion-common = { git = "https://github.com/blaze-init/arrow-datafusion.git", rev = "71433f743b2c399ea1728531b0e56fd7c6ef5282"} -datafusion-expr = { git = "https://github.com/blaze-init/arrow-datafusion.git", rev = "71433f743b2c399ea1728531b0e56fd7c6ef5282"} -datafusion-execution = { git = "https://github.com/blaze-init/arrow-datafusion.git", rev = "71433f743b2c399ea1728531b0e56fd7c6ef5282"} -datafusion-optimizer = { git = "https://github.com/blaze-init/arrow-datafusion.git", rev = "71433f743b2c399ea1728531b0e56fd7c6ef5282"} -datafusion-physical-expr = { git = "https://github.com/blaze-init/arrow-datafusion.git", rev = "71433f743b2c399ea1728531b0e56fd7c6ef5282"} +datafusion = { git = "https://github.com/blaze-init/arrow-datafusion.git", rev = "17b1ad3c7432391b94dd54e48a60db6d5712a7ef"} +datafusion-common = { git = "https://github.com/blaze-init/arrow-datafusion.git", rev = "17b1ad3c7432391b94dd54e48a60db6d5712a7ef"} +datafusion-expr = { git = "https://github.com/blaze-init/arrow-datafusion.git", rev = "17b1ad3c7432391b94dd54e48a60db6d5712a7ef"} +datafusion-execution = { git = 
"https://github.com/blaze-init/arrow-datafusion.git", rev = "17b1ad3c7432391b94dd54e48a60db6d5712a7ef"} +datafusion-optimizer = { git = "https://github.com/blaze-init/arrow-datafusion.git", rev = "17b1ad3c7432391b94dd54e48a60db6d5712a7ef"} +datafusion-physical-expr = { git = "https://github.com/blaze-init/arrow-datafusion.git", rev = "17b1ad3c7432391b94dd54e48a60db6d5712a7ef"} # arrow: branch=v50-blaze -arrow = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} -arrow-arith = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} -arrow-array = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} -arrow-buffer = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} -arrow-cast = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} -arrow-data = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} -arrow-ord = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} -arrow-row = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} -arrow-schema = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} -arrow-select = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} -arrow-string = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} -parquet = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} +arrow = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} +arrow-arith = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} +arrow-array = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} +arrow-buffer = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} +arrow-cast = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} +arrow-data = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} +arrow-ord = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} +arrow-row = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} +arrow-schema = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} +arrow-select = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} +arrow-string = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} +parquet = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} # serde_json: branch=v1.0.96-blaze serde_json = { git = 
"https://github.com/blaze-init/json", branch = "v1.0.96-blaze" } diff --git a/README.md b/README.md index 672528fa2..eeb171e13 100644 --- a/README.md +++ b/README.md @@ -73,7 +73,7 @@ Blaze._ ```shell SHIM=spark333 # or spark303 -MODE=release # or dev +MODE=release # or pre mvn package -P"${SHIM}" -P"${MODE}" ``` @@ -94,11 +94,16 @@ This section describes how to submit and configure a Spark Job with Blaze suppor 1. move blaze jar package to spark client classpath (normally `spark-xx.xx.xx/jars/`). 2. add the follow confs to spark configuration in `spark-xx.xx.xx/conf/spark-default.conf`: + ```properties +spark.blaze.enable true spark.sql.extensions org.apache.spark.sql.blaze.BlazeSparkSessionExtension spark.shuffle.manager org.apache.spark.sql.execution.blaze.shuffle.BlazeShuffleManager +spark.memory.offHeap.enabled false -# other blaze confs defined in spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java +# suggested executor memory configuration +spark.executor.memory 4g +spark.executor.memoryOverhead 4096 ``` 3. submit a query with spark-sql, or other tools like spark-thriftserver: @@ -108,16 +113,15 @@ spark-sql -f tpcds/q01.sql ## Performance -Check [Benchmark Results](./benchmark-results/20240202.md) with the latest date for the performance -comparison with vanilla Spark on TPC-DS 1TB dataset. The benchmark result shows that Blaze saved -~55% query time and ~60% cluster resources in average. ~6x performance achieved for the best case (q06). +Check [Benchmark Results](./benchmark-results/20240701-blaze300.md) with the latest date for the performance +comparison with vanilla Spark 3.3.3. The benchmark result shows that Blaze save about 50% time on TPC-DS/TPC-H 1TB datasets. Stay tuned and join us for more upcoming thrilling numbers. -Query time: -![20240202-query-time](./benchmark-results/blaze-query-time-comparison-20240202.png) +TPC-DS Query time: +![20240701-query-time-tpcds](./benchmark-results/spark333-vs-blaze300-query-time-20240701.png) -Cluster resources: -![20240202-resources](./benchmark-results/blaze-cluster-resources-cost-comparison-20240202.png) +TPC-H Query time: +![20240701-query-time-tpch](./benchmark-results/spark333-vs-blaze300-query-time-20240701-tpch.png) We also encourage you to benchmark Blaze and share the results with us. 🤗 diff --git a/RELEASES.md b/RELEASES.md index 795eeb7d9..551e0e674 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,12 +1,15 @@ -# blaze-v2.0.9.1 +# blaze-v3.0.0 ## Features -* Supports failing-back nondeterministic expressions. -* Supports "$[].xxx" jsonpath syntax in get_json_object(). +* Supports using spark.io.compression.codec for shuffle/broadcast compression +* Supports date type casting +* Refactor join implementations to support existence joins and BHJ building hash map on driver side ## Performance -* Supports adaptive batch size in ParquetScan, improving vectorized reading performance. -* Supports directly spill to disk file when on-heap memory is full. +* Fixed performance issues when running on spark3 with default configurations +* Use cached parquet metadata +* Refactor native broadcast to avoid duplicated broadcast jobs +* Supports spark333 batch shuffle reading ## Bugfix -* Fix incorrect parquet rowgroup pruning with files containing deprecated min/max values. 
+* Fix in_list conversion in from_proto.rs diff --git a/benchmark-results/20240202.md b/benchmark-results/20240202.md deleted file mode 100644 index 98771aea1..000000000 --- a/benchmark-results/20240202.md +++ /dev/null @@ -1,152 +0,0 @@ - -# Report 2024-02-02 - -### Versions -- Blaze version: [2.0.8](https://github.com/blaze-init/blaze/tree/v2.0.8) -- Vanilla spark version: spark-3.3.3 - -### Environment -Hadoop 2.6.0 cluster mode running on 4 nodes, See [Kwai server conf](./kwai1-hardware-conf.md). - -### Configuration - -- Blaze -```properties -spark.executor.memory 5g -spark.executor.memoryOverhead 3072 -spark.blaze.memoryFraction 0.7 -spark.blaze.enable.caseconvert.functions true -spark.blaze.enable.smjInequalityJoin false -spark.blaze.enable.bhjFallbacksToSmj false -``` - -- Vanilla Spark -```properties -spark.executor.memory 6g -spark.executor.memoryOverhead 2048 -``` - -- Common configurations -```properties -spark.speculation false -spark.sql.adaptive.coalescePartitions.initialPartitionNum 1000 -spark.sql.adaptive.coalescePartitions.minPartitionNum 20 -spark.sql.adaptive.coalescePartitions.minPartitionSize 65536 -spark.sql.files.maxPartitionBytes 268435456 -spark.sql.autoBroadcastJoinThreshold 20971520 -``` - -### Results -Query time comparison (seconds): -![blaze-query-time-comparison-20240202.png](blaze-query-time-comparison-20240202.png) - -Executor time comparison (Memory Bytes * Seconds): -![blaze-cluster-resources-cost-comparison-20240202.png](blaze-cluster-resources-cost-comparison-20240202.png) - -| | Memcost Spark | Memcost Blaze | Blaze / non-Blaze | | Query time Spark | Query time Blaze | Blaze / non-Blaze | -| ---- | ------------- | ------------- | ----------------- | ---- | ---------------- | ---------------- | ----------------- | -| q01 | 1064427008 | 449834240 | 0.422606939 | q01 | 10.6 | 6.6 | 0.622641509 | -| q02 | 2368745984 | 2721906176 | 1.149091627 | q02 | 13 | 49.4 | 3.8 | -| q03 | 2393231360 | 1122979200 | 0.469231358 | q03 | 12 | 7.4 | 0.616666667 | -| q04 | 70389727232 | 24298207232 | 0.34519536 | q04 | 292.8 | 105.3 | 0.359631148 | -| q05 | 14231480320 | 4769662976 | 0.33514876 | q05 | 58.5 | 22.4 | 0.382905983 | -| q06 | 14682045440 | 2202324480 | 0.150001203 | q06 | 72 | 11.9 | 0.165277778 | -| q07 | 5505946624 | 3115785216 | 0.565894555 | q07 | 27.3 | 17.6 | 0.644688645 | -| q08 | 2245357056 | 1041340864 | 0.463775176 | q08 | 12.3 | 7.5 | 0.609756098 | -| q09 | 12896677888 | 6773364224 | 0.525202248 | q09 | 53.6 | 28.8 | 0.537313433 | -| q10 | 3314154496 | 975653056 | 0.294389733 | q10 | 22.8 | 15.5 | 0.679824561 | -| q11 | 11654519808 | 12878838784 | 1.105051001 | q11 | 53.4 | 58.8 | 1.101123596 | -| q12 | 686692864 | 502959360 | 0.732437144 | q12 | 5 | 4.8 | 0.96 | -| q13 | 7461695488 | 3379008256 | 0.452847247 | q13 | 37.2 | 17.2 | 0.462365591 | -| q14a | 51038400512 | 18640779264 | 0.365230475 | q14a | 323 | 182 | 0.563467492 | -| q14b | 37148839936 | 16267428864 | 0.437898704 | q14b | 259.7 | 177.7 | 0.684251059 | -| q15 | 1749106432 | 1486633472 | 0.849938829 | q15 | 9.5 | 8.9 | 0.936842105 | -| q16 | 27584167936 | 13769887744 | 0.499195327 | q16 | 123.3 | 65.1 | 0.527980535 | -| q17 | 58074173440 | 14808196096 | 0.254987634 | q17 | 257.3 | 67.6 | 0.262728333 | -| q18 | 3536184832 | 2220624640 | 0.627971881 | q18 | 21.8 | 14.2 | 0.651376147 | -| q19 | 3174092032 | 1489778944 | 0.469355938 | q19 | 24.1 | 9.4 | 0.390041494 | -| q20 | 1256781568 | 897696896 | 0.714282353 | q20 | 8 | 5.6 | 0.7 | -| q21 | 8047820 | 2254438 | 0.280130271 | q21 
| 5.4 | 2.1 | 0.388888889 | -| q22 | 12014388 | 5016781 | 0.417564424 | q22 | 7.4 | 3.2 | 0.432432432 | -| q23a | 81502380032 | 31732547584 | 0.389345042 | q23a | 357.9 | 149.3 | 0.41715563 | -| q23b | 80083304448 | 30878326784 | 0.385577581 | q23b | 351.1 | 136.8 | 0.389632583 | -| q24a | 49402916864 | 16241479680 | 0.328755481 | q24a | 215.1 | 79.8 | 0.370990237 | -| q24b | 49900470272 | 16226333696 | 0.325173963 | q24b | 215.9 | 79.4 | 0.367762853 | -| q25 | 78905024512 | 18997395456 | 0.24076281 | q25 | 352.4 | 84.6 | 0.240068104 | -| q26 | 2931277056 | 1529305728 | 0.521719953 | q26 | 16.5 | 10.9 | 0.660606061 | -| q27 | 5451323904 | 2849344768 | 0.522688583 | q27 | 28.1 | 16.4 | 0.583629893 | -| q28 | 21703208960 | 11386787840 | 0.524659181 | q28 | 88.8 | 49 | 0.551801802 | -| q29 | 54227959808 | 16019457024 | 0.295409547 | q29 | 242 | 74.6 | 0.308264463 | -| q30 | 302891104 | 135692304 | 0.447990391 | q30 | 8.6 | 5.4 | 0.627906977 | -| q31 | 9615978496 | 4488949248 | 0.466821889 | q31 | 41.2 | 23.3 | 0.565533981 | -| q32 | 2419902464 | 958353984 | 0.396030005 | q32 | 14.6 | 6.6 | 0.452054795 | -| q33 | 4866960896 | 2139079296 | 0.43951027 | q33 | 22.6 | 13.7 | 0.60619469 | -| q34 | 2799148544 | 1131107328 | 0.404089783 | q34 | 15.4 | 11.5 | 0.746753247 | -| q35 | 5119056896 | 1135299840 | 0.221779102 | q35 | 42.5 | 16.1 | 0.378823529 | -| q36 | 4177485312 | 2434962944 | 0.582877679 | q36 | 21.4 | 13.8 | 0.644859813 | -| q37 | 7496896512 | 932590400 | 0.124396862 | q37 | 49.1 | 21 | 0.427698574 | -| q38 | 4062855168 | 2151781376 | 0.529622959 | q38 | 30.4 | 17 | 0.559210526 | -| q39a | 17670146 | 6114509 | 0.346036133 | q39a | 5.5 | 2.5 | 0.454545455 | -| q39b | 15856440 | 6545408 | 0.412791774 | q39b | 5.3 | 2.5 | 0.471698113 | -| q40 | 15339921408 | 5915641344 | 0.385637005 | q40 | 75.2 | 31.3 | 0.416223404 | -| q41 | 429260 | 339148 | 0.790075945 | q41 | 0.7 | 0.8 | 1.142857143 | -| q42 | 2318457088 | 973158720 | 0.419744116 | q42 | 11.8 | 5.7 | 0.483050847 | -| q43 | 2643484416 | 1623958784 | 0.614325083 | q43 | 32.7 | 8.4 | 0.256880734 | -| q44 | 4751142912 | 2148898048 | 0.452290762 | q44 | 23.4 | 45.1 | 1.927350427 | -| q45 | 859250816 | 774846720 | 0.90177013 | q45 | 5.7 | 5.4 | 0.947368421 | -| q46 | 5104092160 | 2574011392 | 0.504303471 | q46 | 25.9 | 16.4 | 0.633204633 | -| q47 | 4227521024 | 3525534464 | 0.833948417 | q47 | 22.9 | 21.8 | 0.951965066 | -| q48 | 5522035200 | 2340343040 | 0.423818928 | q48 | 27.8 | 12.3 | 0.442446043 | -| q49 | 18826491904 | 7428071936 | 0.394554226 | q49 | 85.9 | 34.6 | 0.402793946 | -| q50 | 40762097664 | 7524097536 | 0.184585631 | q50 | 177.7 | 35.4 | 0.199212155 | -| q51 | 6946065920 | 2681179904 | 0.385999778 | q51 | 42.2 | 16.5 | 0.390995261 | -| q52 | 2423949056 | 989473792 | 0.408207338 | q52 | 13.3 | 6.3 | 0.473684211 | -| q53 | 2978816512 | 1240832768 | 0.416552266 | q53 | 14.8 | 8.2 | 0.554054054 | -| q54 | 18391373824 | 4565063680 | 0.248217655 | q54 | 111.6 | 22.2 | 0.198924731 | -| q55 | 3162129152 | 972509952 | 0.307549093 | q55 | 105.5 | 6.9 | 0.065402844 | -| q56 | 5093037568 | 2148585216 | 0.421867145 | q56 | 23.8 | 13.4 | 0.56302521 | -| q57 | 1928995200 | 1695657856 | 0.879036846 | q57 | 11.3 | 12.6 | 1.115044248 | -| q58 | 4623944704 | 3115723520 | 0.673823698 | q58 | 22.1 | 17.3 | 0.78280543 | -| q59 | 3398420992 | 3844507648 | 1.131262918 | q59 | 16.7 | 20.4 | 1.221556886 | -| q60 | 5130304512 | 2142691072 | 0.41765378 | q60 | 23.9 | 13.9 | 0.581589958 | -| q61 | 7130669568 | 3692626944 | 0.517851361 | q61 | 31.7 | 
17.4 | 0.548895899 | -| q62 | 861110208 | 734616064 | 0.853103421 | q62 | 5.2 | 5.7 | 1.096153846 | -| q63 | 2913974784 | 1220212608 | 0.418745081 | q63 | 14.3 | 8.3 | 0.58041958 | -| q64 | 90821681152 | 45009911808 | 0.495585539 | q64 | 393 | 217.5 | 0.553435115 | -| q65 | 6929706496 | 4274123520 | 0.616782763 | q65 | 31 | 21 | 0.677419355 | -| q66 | 4168856064 | 2715433216 | 0.65136171 | q66 | 43 | 14.1 | 0.327906977 | -| q67 | 16813858816 | 11755253760 | 0.699140744 | q67 | 80.9 | 55.6 | 0.687268232 | -| q68 | 6089261056 | 3126788096 | 0.513492207 | q68 | 30.5 | 17.9 | 0.586885246 | -| q69 | 3271612928 | 955816128 | 0.292154405 | q69 | 20.3 | 10.7 | 0.527093596 | -| q70 | 5136711680 | 2733509120 | 0.532151557 | q70 | 24.8 | 14.6 | 0.588709677 | -| q71 | 5308280832 | 2298735360 | 0.433047051 | q71 | 24.9 | 17.5 | 0.702811245 | -| q72 | 34922856448 | 25395292160 | 0.72718256 | q72 | 174.4 | 131 | 0.751146789 | -| q73 | 2774287872 | 1024954944 | 0.369447942 | q73 | 14.1 | 9.1 | 0.645390071 | -| q74 | 8379421696 | 6975468544 | 0.832452262 | q74 | 38.8 | 33.8 | 0.871134021 | -| q75 | 21823078400 | 10669447168 | 0.488906605 | q75 | 98.5 | 57.9 | 0.587817259 | -| q76 | 5128589824 | 1943261056 | 0.378907482 | q76 | 36.8 | 11.7 | 0.317934783 | -| q77 | 6024270336 | 2879578880 | 0.477996292 | q77 | 33.3 | 15 | 0.45045045 | -| q78 | 90786725888 | 36193472512 | 0.398664807 | q78 | 377 | 154.4 | 0.409549072 | -| q79 | 5194788864 | 2371766016 | 0.456566393 | q79 | 25.4 | 14.4 | 0.566929134 | -| q80 | 86372556800 | 32238481408 | 0.373249127 | q80 | 351.9 | 133.5 | 0.379369139 | -| q81 | 537159104 | 260399136 | 0.484770963 | q81 | 9.4 | 6.2 | 0.659574468 | -| q82 | 14299611136 | 1773495680 | 0.12402405 | q82 | 83.6 | 22.9 | 0.273923445 | -| q83 | 370437344 | 190986624 | 0.515570655 | q83 | 13.5 | 8.7 | 0.644444444 | -| q84 | 719424896 | 87703536 | 0.121907841 | q84 | 12.9 | 2.7 | 0.209302326 | -| q85 | 2387102720 | 839722560 | 0.351774791 | q85 | 28.9 | 11.6 | 0.401384083 | -| q86 | 810701824 | 355365728 | 0.438343319 | q86 | 35.8 | 4.7 | 0.131284916 | -| q87 | 4160053504 | 2180360192 | 0.524118305 | q87 | 30.6 | 17 | 0.555555556 | -| q88 | 14849074176 | 5351185920 | 0.360371688 | q88 | 61.4 | 24.8 | 0.403908795 | -| q89 | 3119070464 | 1416152448 | 0.454030284 | q89 | 15.7 | 10.2 | 0.649681529 | -| q90 | 947142592 | 274404096 | 0.28971783 | q90 | 7.5 | 2.7 | 0.36 | -| q91 | 140899136 | 61533388 | 0.436719413 | q91 | 5.3 | 2.9 | 0.547169811 | -| q92 | 1332381184 | 475557024 | 0.356922651 | q92 | 9.2 | 4.1 | 0.445652174 | -| q93 | 45288034304 | 12807971840 | 0.282811388 | q93 | 200.3 | 61.8 | 0.308537194 | -| q94 | 16076772352 | 6874429440 | 0.427600098 | q94 | 82.9 | 36.3 | 0.43787696 | -| q95 | 23590234112 | 17770180608 | 0.753285471 | q95 | 120.8 | 103.8 | 0.859271523 | -| q96 | 1928228480 | 637357376 | 0.330540381 | q96 | 9.8 | 4.1 | 0.418367347 | -| q97 | 10208055296 | 3761395456 | 0.368473264 | q97 | 56.9 | 26.3 | 0.462214411 | -| q98 | 2648099584 | 1812735104 | 0.684541894 | q98 | 15.4 | 11.8 | 0.766233766 | -| q99 | 1629526400 | 1538929536 | 0.944402948 | q99 | 8.7 | 9 | 1.034482759 | -| | | | | | | | | -| sum | 1.50781E+12 | 6.1011E+11 | | | 7367.7 | 3390.8 | | diff --git a/benchmark-results/20240701-blaze300.md b/benchmark-results/20240701-blaze300.md new file mode 100644 index 000000000..1106d5451 --- /dev/null +++ b/benchmark-results/20240701-blaze300.md @@ -0,0 +1,201 @@ + +# Report 2024-07-01 + +### Versions +- Blaze version: [3.0.0](https://github.com/blaze-init/blaze/tree/v3.0.0) 
+- Vanilla spark version: spark-3.3.3 opensource version + +### Environment +Hadoop 2.10.2 cluster mode running on 7 nodes, See [Kwai server conf](./kwai1-hardware-conf.md). + +### Configuration + +- Blaze +```properties +spark.blaze.enable true +spark.sql.extensions org.apache.spark.sql.blaze.BlazeSparkSessionExtension +spark.shuffle.manager org.apache.spark.sql.execution.blaze.shuffle.BlazeShuffleManager +spark.memory.offHeap.enabled false + +# suggested executor memory configuration +spark.executor.memory 4g +spark.executor.memoryOverhead 4096 +``` + +- Vanilla Spark +```properties +spark.executor.memory 6g +spark.executor.memoryOverhead 2048 +``` + +- Common configurations +```properties +spark.master yarn +spark.yarn.stagingDir.list hdfs://blaze-test/home/spark/user/ + +spark.eventLog.enabled true +spark.eventLog.dir hdfs://blaze-test/home/yarn/spark-eventlog +spark.history.fs.logDirectory hdfs://blaze-test/home/yarn/spark-eventlog + +spark.externalBlockStore.url.list hdfs://blaze-test/home/platform +spark.driver.extraJavaOptions -XX:+HeapDumpOnOutOfMemoryError -XX:HeapDumpPath=/media/disk1/spark/ -Djava.io.tmpdir=/media/disk1/tmp -Dlog4j2.formatMsgNoLookups=true +spark.local.dir /media/disk1/spark/localdir + +spark.shuffle.service.enabled true +spark.shuffle.service.port 7337 + +spark.driver.memory 20g +spark.driver.memoryOverhead 4096 + +spark.executor.instances 10000 +spark.dynamicallocation.maxExecutors 10000 +spark.executor.cores 5 + +spark.io.compression.codec zstd + +# spark3.3+ disable char/varchar padding +spark.sql.readSideCharPadding false +``` + +### TPC-DS Results +Blaze saved 46% total query time comparing to spark, benchmarks using the above configuration. +Query time comparison (seconds): +![spark333-vs-blaze300-query-time-20240701.png](spark333-vs-blaze300-query-time-20240701.png) + +| | Blaze | Spark | Speedup(x) | +| ------ | -------- | -------- | ---------- | +| q1 | 8.946 | 15.073 | 1.68 | +| q2 | 14.558 | 10.482 | 0.72 | +| q3 | 7.892 | 8.239 | 1.04 | +| q4 | 192.499 | 357.818 | 1.86 | +| q5 | 15.943 | 37.471 | 2.35 | +| q6 | 20.139 | 41.034 | 2.04 | +| q7 | 11.274 | 16.532 | 1.47 | +| q8 | 5.922 | 8.885 | 1.50 | +| q9 | 16.797 | 18.52 | 1.10 | +| q10 | 8.908 | 15.634 | 1.76 | +| q11 | 114.716 | 108.502 | 0.95 | +| q12 | 5.923 | 5.642 | 0.95 | +| q13 | 13.641 | 22.386 | 1.64 | +| q14a | 74.169 | 162.795 | 2.19 | +| q14b | 71.721 | 187.395 | 2.61 | +| q15 | 27.634 | 57.961 | 2.10 | +| q16 | 51.884 | 61.577 | 1.19 | +| q17 | 51.08 | 146.804 | 2.87 | +| q18 | 10.858 | 33.96 | 3.13 | +| q19 | 7.726 | 13.939 | 1.80 | +| q20 | 7.926 | 8.858 | 1.12 | +| q21 | 2.456 | 4.68 | 1.91 | +| q22 | 13.092 | 21.736 | 1.66 | +| q23a | 170.957 | 418.987 | 2.45 | +| q23b | 236.542 | 528.576 | 2.23 | +| q24a | 68.542 | 164.819 | 2.40 | +| q24b | 69.916 | 156.236 | 2.23 | +| q25 | 72.066 | 166.081 | 2.30 | +| q26 | 6.899 | 10.715 | 1.55 | +| q27 | 10.116 | 15.436 | 1.53 | +| q28 | 22.418 | 33.152 | 1.48 | +| q29 | 60.618 | 146.133 | 2.41 | +| q30 | | | #DIV/0! 
| +| q31 | 17.947 | 34.14 | 1.90 | +| q32 | 1.13 | 1.207 | 1.07 | +| q33 | 13.376 | 14.912 | 1.11 | +| q34 | 8.123 | 13.009 | 1.60 | +| q35 | 9.667 | 23.604 | 2.44 | +| q36 | 11.766 | 14.016 | 1.19 | +| q37 | 6.912 | 15.854 | 2.29 | +| q38 | 14.037 | 22.247 | 1.58 | +| q39a | 23.506 | 14.385 | 0.61 | +| q39b | 15.658 | 14.812 | 0.95 | +| q40 | 21.323 | 53.26 | 2.50 | +| q41 | 1.636 | 4.159 | 2.54 | +| q42 | 4.386 | 9.169 | 2.09 | +| q43 | 6.184 | 7.436 | 1.20 | +| q44 | 7.177 | 14.616 | 2.04 | +| q45 | 13.453 | 42.177 | 3.14 | +| q46 | 11.486 | 19.182 | 1.67 | +| q47 | 30.546 | 22.316 | 0.73 | +| q48 | 14.617 | 21.361 | 1.46 | +| q49 | 23.28 | 50.818 | 2.18 | +| q50 | 31.91 | 78.861 | 2.47 | +| q51 | 14.767 | 21.594 | 1.46 | +| q52 | 4.402 | 7.476 | 1.70 | +| q53 | 5.893 | 8.728 | 1.48 | +| q54 | 25.024 | 50.243 | 2.01 | +| q55 | 4.607 | 6.456 | 1.40 | +| q56 | 12.508 | 15.227 | 1.22 | +| q57 | 13.066 | 11.239 | 0.86 | +| q58 | 19.706 | 13.923 | 0.71 | +| q59 | 54.67 | 22.584 | 0.41 | +| q60 | 12.192 | 14.538 | 1.19 | +| q61 | 17.277 | 24.019 | 1.39 | +| q62 | 6.786 | 4.75 | 0.70 | +| q63 | 5.849 | 8.618 | 1.47 | +| q64 | 121.374 | 292.479 | 2.41 | +| q65 | 22.164 | 35.938 | 1.62 | +| q66 | 11.608 | 11.713 | 1.01 | +| q67 | 250.606 | 617.991 | 2.47 | +| q68 | 11.785 | 20.809 | 1.77 | +| q69 | 8.581 | 15.859 | 1.85 | +| q70 | 13.819 | 13.642 | 0.99 | +| q71 | 11.371 | 15.644 | 1.38 | +| q72 | 260.602 | 331.28 | 1.27 | +| q73 | 6.464 | 10.715 | 1.66 | +| q74 | 75.604 | 99.419 | 1.31 | +| q75 | 32.963 | 47.765 | 1.45 | +| q76 | 8.876 | 14.479 | 1.63 | +| q77 | 10.093 | 11.927 | 1.18 | +| q78 | 115.239 | 208.522 | 1.81 | +| q79 | 10.867 | 17.617 | 1.62 | +| q80 | 98.788 | 179.521 | 1.82 | +| q81 | 5.629 | 9.121 | 1.62 | +| q82 | 10.638 | 21.232 | 2.00 | +| q83 | 4.092 | 4.624 | 1.13 | +| q84 | 5.152 | 4.547 | 0.88 | +| q85 | 21.524 | 12.985 | 0.60 | +| q86 | 5.256 | 6.288 | 1.20 | +| q87 | 13.981 | 31.153 | 2.23 | +| q88 | 24.139 | 29.653 | 1.23 | +| q89 | 8.124 | 10.193 | 1.25 | +| q90 | 3.507 | 3.555 | 1.01 | +| q91 | 3.121 | 3.695 | 1.18 | +| q92 | 3.525 | 5.06 | 1.44 | +| q93 | 37.673 | 106.785 | 2.83 | +| q94 | 27.961 | 51.743 | 1.85 | +| q95 | 60.55 | 87.68 | 1.45 | +| q96 | 4.003 | 11.587 | 2.89 | +| q97 | 16.22 | 36.13 | 2.23 | +| q98 | 16.228 | 17.334 | 1.07 | +| q99 | 8.363 | 6.842 | 0.82 | +| total: | 3309.135 | 6112.521 | 1.85 | + +### TPC-H Results +Blaze saved 55% total query time comparing to spark, benchmarks using the above configuration. 
+Query time comparison (seconds): +![spark333-vs-blaze300-query-time-20240701-tpch.png](spark333-vs-blaze300-query-time-20240701-tpch.png) + +| | Blaze | Spark | Speedup(x) | +| ------ | ------- | -------- | ---------- | +| q01 | 18.436 | 38.834 | 2.11 | +| q02 | 19.276 | 34.415 | 1.79 | +| q03 | 38.373 | 85.78 | 2.24 | +| q04 | 22.427 | 69.202 | 3.09 | +| q05 | 68.087 | 126.88 | 1.86 | +| q06 | 8.945 | 25.513 | 2.85 | +| q07 | 87.404 | 206.581 | 2.36 | +| q08 | 79.142 | 164.408 | 2.08 | +| q09 | 107.604 | 237.855 | 2.21 | +| q10 | 26.112 | 98.856 | 3.79 | +| q11 | 15.511 | 59.842 | 3.86 | +| q12 | 20.874 | 52.742 | 2.53 | +| q13 | 21.404 | 72.36 | 3.38 | +| q14 | 11.752 | 33.146 | 2.82 | +| q15 | 20.925 | 62.572 | 2.99 | +| q16 | 9.378 | 24.72 | 2.64 | +| q17 | 70.05 | 137.646 | 1.96 | +| q18 | 153.348 | 215.998 | 1.41 | +| q19 | 11.616 | 24.857 | 2.14 | +| q20 | 22.572 | 80.441 | 3.56 | +| q21 | 132.751 | 291.799 | 2.20 | +| q22 | 11.561 | 32.92 | 2.85 | +| total: | 977.548 | 2177.367 | 2.23 | diff --git a/benchmark-results/blaze-cluster-resources-cost-comparison-20240202.png b/benchmark-results/blaze-cluster-resources-cost-comparison-20240202.png deleted file mode 100644 index d88fb04cb..000000000 Binary files a/benchmark-results/blaze-cluster-resources-cost-comparison-20240202.png and /dev/null differ diff --git a/benchmark-results/blaze-query-time-comparison-20240202.png b/benchmark-results/blaze-query-time-comparison-20240202.png deleted file mode 100644 index 14d8db559..000000000 Binary files a/benchmark-results/blaze-query-time-comparison-20240202.png and /dev/null differ diff --git a/benchmark-results/kwai1-hardware-conf.md b/benchmark-results/kwai1-hardware-conf.md index 8cab773f3..180cf4d93 100644 --- a/benchmark-results/kwai1-hardware-conf.md +++ b/benchmark-results/kwai1-hardware-conf.md @@ -5,296 +5,30 @@ Architecture: x86_64 CPU op-mode(s): 32-bit, 64-bit Byte Order: Little Endian -CPU(s): 128 -On-line CPU(s) list: 0-127 +CPU(s): 56 +On-line CPU(s) list: 0-55 Thread(s) per core: 2 -Core(s) per socket: 64 -Socket(s): 1 -NUMA node(s): 1 -Vendor ID: AuthenticAMD -CPU family: 25 -Model: 1 -Model name: AMD EPYC 7713 64-Core Processor +Core(s) per socket: 14 +Socket(s): 2 +NUMA node(s): 2 +Vendor ID: GenuineIntel +CPU family: 6 +Model: 79 +Model name: Intel(R) Xeon(R) CPU E5-2680 v4 @ 2.40GHz Stepping: 1 -CPU MHz: 2697.144 -CPU max MHz: 2000.0000 -CPU min MHz: 1500.0000 -BogoMIPS: 3992.42 -Virtualization: AMD-V +CPU MHz: 2900.810 +CPU max MHz: 3300.0000 +CPU min MHz: 1200.0000 +BogoMIPS: 4800.46 +Virtualization: VT-x L1d cache: 32K L1i cache: 32K -L2 cache: 512K -L3 cache: 32768K -NUMA node0 CPU(s): 0-127 -Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate sme ssbd mba sev ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold 
v_vmsave_vmload vgif umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca +L2 cache: 256K +L3 cache: 35840K +NUMA node0 CPU(s): 0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36,38,40,42,44,46,48,50,52,54 +NUMA node1 CPU(s): 1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31,33,35,37,39,41,43,45,47,49,51,53,55 +Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a rdseed adx smap intel_pt xsaveopt cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts ``` -### 2. fdisk -l - - +### 2. disk All HDD disks. - -``` -Disk /dev/sda: 240.1 GB, 240057409536 bytes, 468862128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 512 bytes -I/O size (minimum/optimal): 512 bytes / 512 bytes -Disk label type: dos -Disk identifier: 0xc67f9894 - -Device Boot Start End Blocks Id System -/dev/sda1 * 2048 976895 487424 83 Linux -/dev/sda2 976896 199219199 99121152 83 Linux -/dev/sda3 199219200 468860927 134820864 83 Linux - -Disk /dev/sdd: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sde: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdf: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdg: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdh: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdi: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdj: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdk: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdl: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdm: 18000.2 
GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdn: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdo: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdp: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdq: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdc: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdt: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdu: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdb: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdv: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdw: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdx: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdr: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdz: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sds: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdac: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk 
/dev/sdad: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdy: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdaf: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdag: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdah: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdae: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdai: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdab: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdaj: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdak: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdal: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdam: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdan: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdao: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdap: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes - - -Disk /dev/sdaq: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 
bytes / 4096 bytes - - -Disk /dev/sdaa: 18000.2 GB, 18000207937536 bytes, 35156656128 sectors -Units = sectors of 1 * 512 = 512 bytes -Sector size (logical/physical): 512 bytes / 4096 bytes -I/O size (minimum/optimal): 4096 bytes / 4096 bytes -``` diff --git a/benchmark-results/spark333-vs-blaze300-query-time-20240701-tpch.png b/benchmark-results/spark333-vs-blaze300-query-time-20240701-tpch.png new file mode 100644 index 000000000..ca82b23e8 Binary files /dev/null and b/benchmark-results/spark333-vs-blaze300-query-time-20240701-tpch.png differ diff --git a/benchmark-results/spark333-vs-blaze300-query-time-20240701.png b/benchmark-results/spark333-vs-blaze300-query-time-20240701.png new file mode 100644 index 000000000..71496e57b Binary files /dev/null and b/benchmark-results/spark333-vs-blaze300-query-time-20240701.png differ diff --git a/native-engine/blaze-jni-bridge/src/conf.rs b/native-engine/blaze-jni-bridge/src/conf.rs index dd476ed64..9eccc0e5d 100644 --- a/native-engine/blaze-jni-bridge/src/conf.rs +++ b/native-engine/blaze-jni-bridge/src/conf.rs @@ -41,6 +41,8 @@ define_conf!(BooleanConf, IGNORE_CORRUPTED_FILES); define_conf!(BooleanConf, PARTIAL_AGG_SKIPPING_ENABLE); define_conf!(DoubleConf, PARTIAL_AGG_SKIPPING_RATIO); define_conf!(IntConf, PARTIAL_AGG_SKIPPING_MIN_ROWS); +define_conf!(BooleanConf, PARQUET_ENABLE_PAGE_FILTERING); +define_conf!(BooleanConf, PARQUET_ENABLE_BLOOM_FILTER); pub trait BooleanConf { fn key(&self) -> &'static str; diff --git a/native-engine/blaze-serde/proto/blaze.proto b/native-engine/blaze-serde/proto/blaze.proto index 9818a136d..b424f1fcf 100644 --- a/native-engine/blaze-serde/proto/blaze.proto +++ b/native-engine/blaze-serde/proto/blaze.proto @@ -35,19 +35,19 @@ message PhysicalPlanNode { FilterExecNode filter = 8; UnionExecNode union = 9; SortMergeJoinExecNode sort_merge_join = 10; - BroadcastJoinExecNode broadcast_join = 11; - RenameColumnsExecNode rename_columns = 12; - EmptyPartitionsExecNode empty_partitions = 13; - AggExecNode agg = 14; - LimitExecNode limit = 15; - FFIReaderExecNode ffi_reader = 16; - CoalesceBatchesExecNode coalesce_batches = 17; - ExpandExecNode expand = 18; - RssShuffleWriterExecNode rss_shuffle_writer= 19; - WindowExecNode window = 20; - GenerateExecNode generate = 21; - ParquetSinkExecNode parquet_sink = 22; - BroadcastNestedLoopJoinExecNode broadcast_nested_loop_join = 23; + BroadcastJoinBuildHashMapExecNode broadcast_join_build_hash_map = 11; + BroadcastJoinExecNode broadcast_join = 12; + RenameColumnsExecNode rename_columns = 13; + EmptyPartitionsExecNode empty_partitions = 14; + AggExecNode agg = 15; + LimitExecNode limit = 16; + FFIReaderExecNode ffi_reader = 17; + CoalesceBatchesExecNode coalesce_batches = 18; + ExpandExecNode expand = 19; + RssShuffleWriterExecNode rss_shuffle_writer= 20; + WindowExecNode window = 21; + GenerateExecNode generate = 22; + ParquetSinkExecNode parquet_sink = 23; } } @@ -398,27 +398,28 @@ enum PartitionMode { } message SortMergeJoinExecNode { - PhysicalPlanNode left = 1; - PhysicalPlanNode right = 2; - repeated JoinOn on = 3; - repeated SortOptions sort_options = 4; - JoinType join_type = 5; - JoinFilter join_filter = 6; + Schema schema = 1; + PhysicalPlanNode left = 2; + PhysicalPlanNode right = 3; + repeated JoinOn on = 4; + repeated SortOptions sort_options = 5; + JoinType join_type = 6; + JoinFilter join_filter = 7; } -message BroadcastJoinExecNode { - PhysicalPlanNode left = 1; - PhysicalPlanNode right = 2; - repeated JoinOn on = 3; - JoinType join_type = 4; - JoinFilter 
join_filter = 5; +message BroadcastJoinBuildHashMapExecNode { + PhysicalPlanNode input = 1; + repeated PhysicalExprNode keys =2; } -message BroadcastNestedLoopJoinExecNode { - PhysicalPlanNode left = 1; - PhysicalPlanNode right = 2; - JoinType join_type = 3; - JoinFilter join_filter = 4; +message BroadcastJoinExecNode { + Schema schema = 1; + PhysicalPlanNode left = 2; + PhysicalPlanNode right = 3; + repeated JoinOn on = 4; + JoinType join_type = 5; + JoinSide broadcast_side = 6; + string cached_build_hash_map_id = 7; } message RenameColumnsExecNode { @@ -438,6 +439,7 @@ enum JoinType { FULL = 3; SEMI = 4; ANTI = 5; + EXISTENCE = 6; } message SortOptions { @@ -456,8 +458,8 @@ message BoundReference { } message JoinOn { - PhysicalColumn left = 1; - PhysicalColumn right = 2; + PhysicalExprNode left = 1; + PhysicalExprNode right = 2; } message ProjectionExecNode { diff --git a/native-engine/blaze-serde/src/from_proto.rs b/native-engine/blaze-serde/src/from_proto.rs index 1f4e82425..cc89de0a5 100644 --- a/native-engine/blaze-serde/src/from_proto.rs +++ b/native-engine/blaze-serde/src/from_proto.rs @@ -45,7 +45,6 @@ use datafusion::{ BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, Literal, NegativeExpr, NotExpr, PhysicalSortExpr, }, - joins::utils::{ColumnIndex, JoinFilter}, union::UnionExec, ColumnStatistics, ExecutionPlan, Partitioning, PhysicalExpr, Statistics, }, @@ -61,8 +60,8 @@ use datafusion_ext_exprs::{ use datafusion_ext_plans::{ agg::{create_agg, AggExecMode, AggExpr, AggFunction, AggMode, GroupingExpr}, agg_exec::AggExec, + broadcast_join_build_hash_map_exec::BroadcastJoinBuildHashMapExec, broadcast_join_exec::BroadcastJoinExec, - broadcast_nested_loop_join_exec::BroadcastNestedLoopJoinExec, debug_exec::DebugExec, empty_partitions_exec::EmptyPartitionsExec, expand_exec::ExpandExec, @@ -89,7 +88,7 @@ use object_store::{path::Path, ObjectMeta}; use crate::{ convert_box_required, convert_required, error::PlanSerDeError, - from_proto_binary_op, into_required, proto_error, protobuf, + from_proto_binary_op, proto_error, protobuf, protobuf::{ physical_expr_node::ExprType, physical_plan_node::PhysicalPlanType, GenerateFunction, }, @@ -182,19 +181,20 @@ impl TryInto> for &protobuf::PhysicalPlanNode { ))) } PhysicalPlanType::SortMergeJoin(sort_merge_join) => { + let schema = Arc::new(convert_required!(sort_merge_join.schema)?); let left: Arc = convert_box_required!(sort_merge_join.left)?; let right: Arc = convert_box_required!(sort_merge_join.right)?; let on: Vec<(Arc, Arc)> = sort_merge_join .on .iter() .map(|col| { - let left_col: Column = into_required!(col.left)?; - let left_col_binded: Arc = - Arc::new(Column::new_with_schema(left_col.name(), &left.schema())?); - let right_col: Column = into_required!(col.right)?; - let right_col_binded: Arc = - Arc::new(Column::new_with_schema(right_col.name(), &right.schema())?); - Ok((left_col_binded, right_col_binded)) + let left_key = + try_parse_physical_expr(&col.left.as_ref().unwrap(), &left.schema())?; + let left_key_binded = bind(left_key, &left.schema())?; + let right_key = + try_parse_physical_expr(&col.right.as_ref().unwrap(), &right.schema())?; + let right_key_binded = bind(right_key, &right.schema())?; + Ok((left_key_binded, right_key_binded)) }) .collect::>()?; @@ -210,38 +210,14 @@ impl TryInto> for &protobuf::PhysicalPlanNode { let join_type = protobuf::JoinType::try_from(sort_merge_join.join_type) .expect("invalid JoinType"); - let join_filter = sort_merge_join - .join_filter - .as_ref() - .map(|f| { - let schema = 
Arc::new(convert_required!(f.schema)?); - let expression = try_parse_physical_expr_required(&f.expression, &schema)?; - let column_indices = f - .column_indices - .iter() - .map(|i| { - let side = - protobuf::JoinSide::try_from(i.side).expect("invalid JoinSide"); - Ok(ColumnIndex { - index: i.index as usize, - side: side.into(), - }) - }) - .collect::, PlanSerDeError>>()?; - - Ok(JoinFilter::new( - bind(expression, &schema)?, - column_indices, - schema.as_ref().clone(), - )) - }) - .map_or(Ok(None), |v: Result<_, PlanSerDeError>| v.map(Some))?; Ok(Arc::new(SortMergeJoinExec::try_new( + schema, left, right, on, - join_type.into(), - join_filter, + join_type + .try_into() + .map_err(|_| proto_error("invalid JoinType"))?, sort_options, )?)) } @@ -306,7 +282,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { self )) })?; - if let protobuf::physical_expr_node::ExprType::Sort(sort_expr) = expr { + if let ExprType::Sort(sort_expr) = expr { let expr = sort_expr .expr .as_ref() @@ -342,97 +318,58 @@ impl TryInto> for &protobuf::PhysicalPlanNode { sort.fetch_limit.as_ref().map(|limit| limit.limit as usize), ))) } + PhysicalPlanType::BroadcastJoinBuildHashMap(bhm) => { + let input: Arc = convert_box_required!(bhm.input)?; + let keys = bhm + .keys + .iter() + .map(|expr| { + Ok(bind( + try_parse_physical_expr(expr, &input.schema())?, + &input.schema(), + )?) + }) + .collect::>, Self::Error>>()?; + Ok(Arc::new(BroadcastJoinBuildHashMapExec::new(input, keys))) + } PhysicalPlanType::BroadcastJoin(broadcast_join) => { + let schema = Arc::new(convert_required!(broadcast_join.schema)?); let left: Arc = convert_box_required!(broadcast_join.left)?; let right: Arc = convert_box_required!(broadcast_join.right)?; let on: Vec<(Arc, Arc)> = broadcast_join .on .iter() .map(|col| { - let left_col: Column = into_required!(col.left)?; - let left_col_binded: Arc = - Arc::new(Column::new_with_schema(left_col.name(), &left.schema())?); - let right_col: Column = into_required!(col.right)?; - let right_col_binded: Arc = - Arc::new(Column::new_with_schema(right_col.name(), &right.schema())?); - Ok((left_col_binded, right_col_binded)) + let left_key = + try_parse_physical_expr(&col.left.as_ref().unwrap(), &left.schema())?; + let left_key_binded = bind(left_key, &left.schema())?; + let right_key = + try_parse_physical_expr(&col.right.as_ref().unwrap(), &right.schema())?; + let right_key_binded = bind(right_key, &right.schema())?; + Ok((left_key_binded, right_key_binded)) }) .collect::>()?; let join_type = protobuf::JoinType::try_from(broadcast_join.join_type) .expect("invalid JoinType"); - let join_filter = broadcast_join - .join_filter - .as_ref() - .map(|f| { - let schema = Arc::new(convert_required!(f.schema)?); - let expression = try_parse_physical_expr_required(&f.expression, &schema)?; - let column_indices = f - .column_indices - .iter() - .map(|i| { - let side = - protobuf::JoinSide::try_from(i.side).expect("invalid JoinSide"); - Ok(ColumnIndex { - index: i.index as usize, - side: side.into(), - }) - }) - .collect::, PlanSerDeError>>()?; - Ok(JoinFilter::new( - bind(expression, &schema)?, - column_indices, - schema.as_ref().clone(), - )) - }) - .map_or(Ok(None), |v: Result<_, PlanSerDeError>| v.map(Some))?; + let broadcast_side = protobuf::JoinSide::try_from(broadcast_join.broadcast_side) + .expect("invalid BroadcastSide"); + + let cached_build_hash_map_id = broadcast_join.cached_build_hash_map_id.clone(); Ok(Arc::new(BroadcastJoinExec::try_new( + schema, left, right, on, - join_type.into(), - join_filter, - )?)) - } 
- PhysicalPlanType::BroadcastNestedLoopJoin(bnlj) => { - let left: Arc = convert_box_required!(bnlj.left)?; - let right: Arc = convert_box_required!(bnlj.right)?; - let join_type = - protobuf::JoinType::try_from(bnlj.join_type).expect("invalid JoinType"); - let join_filter = bnlj - .join_filter - .as_ref() - .map(|f| { - let schema = Arc::new(convert_required!(f.schema)?); - let expression = try_parse_physical_expr_required(&f.expression, &schema)?; - let column_indices = f - .column_indices - .iter() - .map(|i| { - let side = - protobuf::JoinSide::try_from(i.side).expect("invalid JoinSide"); - Ok(ColumnIndex { - index: i.index as usize, - side: side.into(), - }) - }) - .collect::, PlanSerDeError>>()?; - - Ok(JoinFilter::new( - bind(expression, &schema)?, - column_indices, - schema.as_ref().clone(), - )) - }) - .map_or(Ok(None), |v: Result<_, PlanSerDeError>| v.map(Some))?; - - Ok(Arc::new(BroadcastNestedLoopJoinExec::try_new( - left, - right, - join_type.into(), - join_filter, + join_type + .try_into() + .map_err(|_| proto_error("invalid JoinType"))?, + broadcast_side + .try_into() + .map_err(|_| proto_error("invalid BroadcastSide"))?, + Some(cached_build_hash_map_id), )?)) } PhysicalPlanType::Union(union) => { diff --git a/native-engine/blaze-serde/src/lib.rs b/native-engine/blaze-serde/src/lib.rs index 30bd4c282..56cd4a6bf 100644 --- a/native-engine/blaze-serde/src/lib.rs +++ b/native-engine/blaze-serde/src/lib.rs @@ -15,10 +15,8 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, Schema, TimeUnit}; -use datafusion::{ - common::JoinSide, logical_expr::Operator, prelude::JoinType, scalar::ScalarValue, -}; -use datafusion_ext_plans::agg::AggFunction; +use datafusion::{common::JoinSide, logical_expr::Operator, scalar::ScalarValue}; +use datafusion_ext_plans::{agg::AggFunction, joins::join_utils::JoinType}; use crate::error::PlanSerDeError; @@ -111,6 +109,7 @@ impl From for JoinType { protobuf::JoinType::Full => JoinType::Full, protobuf::JoinType::Semi => JoinType::LeftSemi, protobuf::JoinType::Anti => JoinType::LeftAnti, + protobuf::JoinType::Existence => JoinType::Existence, } } } diff --git a/native-engine/datafusion-ext-commons/src/lib.rs b/native-engine/datafusion-ext-commons/src/lib.rs index 72f622354..ece6438af 100644 --- a/native-engine/datafusion-ext-commons/src/lib.rs +++ b/native-engine/datafusion-ext-commons/src/lib.rs @@ -13,7 +13,6 @@ // limitations under the License. #![feature(new_uninit)] -#![feature(io_error_other)] #![feature(slice_swap_unchecked)] #![feature(vec_into_raw_parts)] @@ -85,9 +84,9 @@ pub fn batch_size() -> usize { batch_size } -// for better cache usage +// bigger for better radix sort performance pub const fn staging_mem_size_for_partial_sort() -> usize { - 4194304 * 8 / 10 + 8388608 } // use bigger batch memory size writing shuffling data diff --git a/native-engine/datafusion-ext-commons/src/spark_hash.rs b/native-engine/datafusion-ext-commons/src/spark_hash.rs index 6a76bb953..85dac3077 100644 --- a/native-engine/datafusion-ext-commons/src/spark_hash.rs +++ b/native-engine/datafusion-ext-commons/src/spark_hash.rs @@ -77,10 +77,8 @@ fn spark_compatible_murmur3_hash>(data: T, seed: u32) -> u32 { // avoid boundary checking in performance critical codes. 
// all operations are garenteed to be safe unsafe { - let mut h1 = hash_bytes_by_int( - std::slice::from_raw_parts(data.get_unchecked(0), len_aligned), - seed, - ); + let mut h1 = + hash_bytes_by_int(std::slice::from_raw_parts(data.as_ptr(), len_aligned), seed); for i in len_aligned..len { let half_word = *data.get_unchecked(i) as i8 as i32; diff --git a/native-engine/datafusion-ext-functions/src/spark_get_json_object.rs b/native-engine/datafusion-ext-functions/src/spark_get_json_object.rs index 966b2f668..ede47a407 100644 --- a/native-engine/datafusion-ext-functions/src/spark_get_json_object.rs +++ b/native-engine/datafusion-ext-functions/src/spark_get_json_object.rs @@ -194,8 +194,8 @@ enum ParsedJsonValue { #[derive(Debug)] enum HiveGetJsonObjectError { - InvalidJsonPath(String), - InvalidInput(String), + InvalidJsonPath, + InvalidInput, } struct HiveGetJsonObjectEvaluator { @@ -212,15 +212,11 @@ impl HiveGetJsonObjectEvaluator { evaluator.matchers.push(matcher); } if evaluator.matchers.first() != Some(&HiveGetJsonObjectMatcher::Root) { - return Err(HiveGetJsonObjectError::InvalidJsonPath( - "json path missing root".to_string(), - )); + return Err(HiveGetJsonObjectError::InvalidJsonPath); } evaluator.matchers.remove(0); // remove root matcher if evaluator.matchers.contains(&HiveGetJsonObjectMatcher::Root) { - return Err(HiveGetJsonObjectError::InvalidJsonPath( - "json path has more than one root".to_string(), - )); + return Err(HiveGetJsonObjectError::InvalidJsonPath); } Ok(evaluator) } @@ -240,9 +236,7 @@ impl HiveGetJsonObjectEvaluator { return Ok(v); } } - Err(HiveGetJsonObjectError::InvalidInput( - "invalid json string".to_string(), - )) + Err(HiveGetJsonObjectError::InvalidInput) } fn evaluate_with_value_serde_json( @@ -296,7 +290,7 @@ fn serde_json_value_to_string( serde_json::Value::Bool(b) => Ok(Some(b.to_string())), serde_json::Value::Array(_) | serde_json::Value::Object(_) => serde_json::to_string(value) .map(Some) - .map_err(|_| HiveGetJsonObjectError::InvalidInput("array to json error".to_string())), + .map_err(|_| HiveGetJsonObjectError::InvalidInput), } } @@ -310,7 +304,7 @@ fn sonic_value_to_string( sonic_rs::JsonType::Boolean => Ok(value.as_bool().map(|v| v.to_string())), _ => sonic_rs::to_string(value) .map(Some) - .map_err(|_| HiveGetJsonObjectError::InvalidInput("array to json error".to_string())), + .map_err(|_| HiveGetJsonObjectError::InvalidInput), } } @@ -352,9 +346,7 @@ impl HiveGetJsonObjectMatcher { } } if child_name.is_empty() { - return Err(HiveGetJsonObjectError::InvalidJsonPath( - "empty child name".to_string(), - )); + return Err(HiveGetJsonObjectError::InvalidJsonPath); } Ok(Some(Self::Child(child_name))) } @@ -372,24 +364,18 @@ impl HiveGetJsonObjectMatcher { chars.next(); } None => { - return Err(HiveGetJsonObjectError::InvalidJsonPath( - "unterminated subscript".to_string(), - )); + return Err(HiveGetJsonObjectError::InvalidJsonPath); } } } if index_str.is_empty() || index_str == "*" { return Ok(Some(Self::SubscriptAll)); } - let index = str::parse::(&index_str).map_err(|_| { - HiveGetJsonObjectError::InvalidJsonPath("invalid subscript index".to_string()) - })?; + let index = str::parse::(&index_str) + .map_err(|_| HiveGetJsonObjectError::InvalidJsonPath)?; Ok(Some(Self::Subscript(index))) } - Some(c) => Err(HiveGetJsonObjectError::InvalidJsonPath(format!( - "unexpected char in json path: {}", - c - ))), + Some(_) => Err(HiveGetJsonObjectError::InvalidJsonPath), } } diff --git a/native-engine/datafusion-ext-functions/src/spark_null_if.rs 
b/native-engine/datafusion-ext-functions/src/spark_null_if.rs index af753d5fa..4845a93b8 100644 --- a/native-engine/datafusion-ext-functions/src/spark_null_if.rs +++ b/native-engine/datafusion-ext-functions/src/spark_null_if.rs @@ -16,10 +16,7 @@ use std::sync::Arc; use arrow::{ array::*, - compute::{ - kernels::{cmp::eq, nullif::nullif}, - *, - }, + compute::kernels::{cmp::eq, nullif::nullif}, datatypes::*, }; use datafusion::{ @@ -87,7 +84,8 @@ pub fn spark_null_if_zero(args: &[ColumnarValue]) -> Result { ($dt:ident) => {{ type T = paste::paste! {arrow::datatypes::[<$dt Type>]}; let array = as_primitive_array::(array); - let eq_zeros = eq_scalar(array, T::default_value())?; + let _0 = PrimitiveArray::::new_scalar(Default::default()); + let eq_zeros = eq(array, &_0)?; Arc::new(nullif(array, &eq_zeros)?) as ArrayRef }}; } diff --git a/native-engine/datafusion-ext-functions/src/spark_strings.rs b/native-engine/datafusion-ext-functions/src/spark_strings.rs index 6eb5d5e94..6deaa7c8d 100644 --- a/native-engine/datafusion-ext-functions/src/spark_strings.rs +++ b/native-engine/datafusion-ext-functions/src/spark_strings.rs @@ -223,7 +223,9 @@ pub fn string_concat_ws(args: &[ColumnarValue]) -> Result { None => return Ok(Arg::Ignore), } } - if let ScalarValue::List(l) = scalar && l.data_type() == &DataType::Utf8 { + if let ScalarValue::List(l) = scalar + && l.data_type() == &DataType::Utf8 + { if l.is_null(0) { return Ok(Arg::Ignore); } diff --git a/native-engine/datafusion-ext-plans/Cargo.toml b/native-engine/datafusion-ext-plans/Cargo.toml index c8412fd41..ce9be4a3c 100644 --- a/native-engine/datafusion-ext-plans/Cargo.toml +++ b/native-engine/datafusion-ext-plans/Cargo.toml @@ -11,6 +11,7 @@ default = ["tokio/rt-multi-thread"] arrow = { workspace = true } async-trait = "0.1.80" base64 = "0.22.1" +bitvec = "1.0.1" byteorder = "1.5.0" bytes = "1.6.0" blaze-jni-bridge = { workspace = true } diff --git a/native-engine/datafusion-ext-plans/src/broadcast_join_build_hash_map_exec.rs b/native-engine/datafusion-ext-plans/src/broadcast_join_build_hash_map_exec.rs new file mode 100644 index 000000000..3f1ca6d65 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/broadcast_join_build_hash_map_exec.rs @@ -0,0 +1,150 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::{ + any::Any, + fmt::{Debug, Formatter}, + sync::Arc, +}; + +use arrow::{compute::concat_batches, datatypes::SchemaRef}; +use datafusion::{ + common::Result, + execution::{SendableRecordBatchStream, TaskContext}, + physical_expr::{Partitioning, PhysicalExpr, PhysicalSortExpr}, + physical_plan::{ + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, + stream::RecordBatchStreamAdapter, + DisplayAs, DisplayFormatType, ExecutionPlan, + }, +}; +use futures::{stream::once, TryStreamExt}; + +use crate::{ + common::output::{NextBatchWithTimer, TaskOutputter}, + joins::join_hash_map::{join_hash_map_schema, JoinHashMap}, +}; + +pub struct BroadcastJoinBuildHashMapExec { + input: Arc, + keys: Vec>, + metrics: ExecutionPlanMetricsSet, +} + +impl BroadcastJoinBuildHashMapExec { + pub fn new(input: Arc, keys: Vec>) -> Self { + Self { + input, + keys, + metrics: ExecutionPlanMetricsSet::new(), + } + } +} + +impl Debug for BroadcastJoinBuildHashMapExec { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "BroadcastJoinBuildHashMap [{:?}]", self.keys) + } +} + +impl DisplayAs for BroadcastJoinBuildHashMapExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + write!(f, "BroadcastJoinBuildHashMapExec [{:?}]", self.keys) + } +} + +impl ExecutionPlan for BroadcastJoinBuildHashMapExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + join_hash_map_schema(&self.input.schema()) + } + + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(self.input.output_partitioning().partition_count()) + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn children(&self) -> Vec> { + vec![self.input.clone()] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(Self::new(children[0].clone(), self.keys.clone()))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + let input = self.input.execute(partition, context.clone())?; + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + once(execute_build_hash_map( + context, + input, + self.keys.clone(), + baseline_metrics, + )) + .try_flatten(), + ))) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } +} + +async fn execute_build_hash_map( + context: Arc, + mut input: SendableRecordBatchStream, + keys: Vec>, + metrics: BaselineMetrics, +) -> Result { + let elapsed_compute = metrics.elapsed_compute().clone(); + let mut timer = elapsed_compute.timer(); + + let mut data_batches = vec![]; + let data_schema = input.schema(); + + // collect all input batches + while let Some(batch) = input.next_batch(Some(&mut timer)).await? 
{ + data_batches.push(batch); + } + let data_batch = concat_batches(&data_schema, data_batches.iter())?; + + // build hash map + let hash_map_schema = join_hash_map_schema(&data_schema); + let hash_map = JoinHashMap::try_from_data_batch(data_batch, &keys)?; + drop(timer); + + // output hash map batches as stream + context.output_with_sender("BuildHashMap", hash_map_schema, move |sender| async move { + let mut timer = elapsed_compute.timer(); + sender + .send(Ok(hash_map.into_hash_map_batch()?), Some(&mut timer)) + .await; + Ok(()) + }) +} diff --git a/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs b/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs index 201173c4d..de160af8f 100644 --- a/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs +++ b/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs @@ -15,90 +15,203 @@ use std::{ any::Any, fmt::{Debug, Formatter}, - sync::Arc, - task::Poll, - time::Duration, + future::Future, + pin::Pin, + sync::{Arc, Weak}, + time::{Duration, Instant}, }; -use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; -use blaze_jni_bridge::{ - conf, - conf::{BooleanConf, IntConf}, +use arrow::{ + array::RecordBatch, + compute::SortOptions, + datatypes::{DataType, SchemaRef}, }; +use async_trait::async_trait; use datafusion::{ - common::{Result, Statistics}, + common::{JoinSide, Result, Statistics}, execution::context::TaskContext, - logical_expr::JoinType, - physical_expr::PhysicalSortExpr, + physical_expr::{PhysicalExprRef, PhysicalSortExpr}, physical_plan::{ - expressions::Column, - joins::{ - utils::{build_join_schema, check_join_is_valid, JoinFilter, JoinOn}, - HashJoinExec, PartitionMode, - }, - memory::MemoryStream, - metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, + joins::utils::JoinOn, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, Time}, stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, }, }; -use datafusion_ext_commons::{df_execution_err, downcast_any}; -use futures::{stream::once, StreamExt, TryStreamExt}; +use datafusion_ext_commons::{ + batch_size, df_execution_err, streams::coalesce_stream::CoalesceInput, +}; +use futures::{StreamExt, TryStreamExt}; +use hashbrown::HashMap; +use once_cell::sync::OnceCell; use parking_lot::Mutex; -use crate::{sort_exec::SortExec, sort_merge_join_exec::SortMergeJoinExec}; +use crate::{ + common::{ + batch_statisitcs::{stat_input, InputBatchStatistics}, + column_pruning::ExecuteWithColumnPruning, + output::{TaskOutputter, WrappedRecordBatchSender}, + }, + joins::{ + bhj::{ + full_join::{ + LProbedFullOuterJoiner, LProbedInnerJoiner, LProbedLeftJoiner, LProbedRightJoiner, + RProbedFullOuterJoiner, RProbedInnerJoiner, RProbedLeftJoiner, RProbedRightJoiner, + }, + semi_join::{ + LProbedExistenceJoiner, LProbedLeftAntiJoiner, LProbedLeftSemiJoiner, + LProbedRightAntiJoiner, LProbedRightSemiJoiner, RProbedExistenceJoiner, + RProbedLeftAntiJoiner, RProbedLeftSemiJoiner, RProbedRightAntiJoiner, + RProbedRightSemiJoiner, + }, + }, + join_hash_map::{join_data_schema, JoinHashMap}, + join_utils::{JoinType, JoinType::*}, + JoinParams, JoinProjection, + }, +}; #[derive(Debug)] pub struct BroadcastJoinExec { - /// Left sorted joining execution plan left: Arc, - /// Right sorting joining execution plan right: Arc, - /// Set of common columns used to join on on: JoinOn, - /// How the join is performed join_type: JoinType, - /// Optional filter before outputting - join_filter: 
Option, - /// The schema once the join is applied + broadcast_side: JoinSide, schema: SchemaRef, - /// Execution metrics + cached_build_hash_map_id: Option, metrics: ExecutionPlanMetricsSet, } impl BroadcastJoinExec { pub fn try_new( + schema: SchemaRef, left: Arc, right: Arc, on: JoinOn, join_type: JoinType, - join_filter: Option, + broadcast_side: JoinSide, + cached_build_hash_map_id: Option, ) -> Result { - if matches!( - join_type, - JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightSemi | JoinType::RightAnti, - ) { - if join_filter.is_some() { - df_execution_err!("Semi/Anti join with filter is not supported yet")?; - } - } - - let left_schema = left.schema(); - let right_schema = right.schema(); - - check_join_is_valid(&left_schema, &right_schema, &on)?; - let schema = Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0); - Ok(Self { left, right, on, join_type, - join_filter, + broadcast_side, schema, + cached_build_hash_map_id, metrics: ExecutionPlanMetricsSet::new(), }) } + + fn create_join_params(&self, projection: &[usize]) -> Result { + let left_schema = self.left.schema(); + let right_schema = self.right.schema(); + let (left_keys, right_keys): (Vec, Vec) = + self.on.iter().cloned().unzip(); + let key_data_types: Vec = self + .on + .iter() + .map(|(left_key, right_key)| { + Ok({ + let left_dt = left_key.data_type(&left_schema)?; + let right_dt = right_key.data_type(&right_schema)?; + if left_dt != right_dt { + df_execution_err!( + "join key data type differs {left_dt:?} <-> {right_dt:?}" + )?; + } + left_dt + }) + }) + .collect::>()?; + + let projection = JoinProjection::try_new( + self.join_type, + &self.schema, + &match self.broadcast_side { + JoinSide::Left => join_data_schema(&left_schema), + JoinSide::Right => left_schema.clone(), + }, + &match self.broadcast_side { + JoinSide::Left => right_schema.clone(), + JoinSide::Right => join_data_schema(&right_schema), + }, + projection, + )?; + + Ok(JoinParams { + join_type: self.join_type, + left_schema, + right_schema, + output_schema: self.schema(), + left_keys, + right_keys, + batch_size: batch_size(), + sort_options: vec![SortOptions::default(); self.on.len()], + projection, + key_data_types, + }) + } + + fn execute_with_projection( + &self, + partition: usize, + context: Arc, + projection: Vec, + ) -> Result { + let metrics = Arc::new(BaselineMetrics::new(&self.metrics, partition)); + let join_params = self.create_join_params(&projection)?; + let left = self.left.execute(partition, context.clone())?; + let right = self.right.execute(partition, context.clone())?; + let broadcast_side = self.broadcast_side; + let cached_build_hash_map_id = self.cached_build_hash_map_id.clone(); + + // stat probed side + let input_batch_stat = + InputBatchStatistics::from_metrics_set_and_blaze_conf(&self.metrics, partition)?; + let (left, right) = match broadcast_side { + JoinSide::Left => (left, stat_input(input_batch_stat, right)?), + JoinSide::Right => (stat_input(input_batch_stat, left)?, right), + }; + + let metrics_cloned = metrics.clone(); + let context_cloned = context.clone(); + let output_stream = Box::pin(RecordBatchStreamAdapter::new( + join_params.projection.schema.clone(), + futures::stream::once(async move { + context_cloned.output_with_sender( + "BroadcastJoin", + join_params.projection.schema.clone(), + move |sender| { + execute_join( + left, + right, + join_params, + broadcast_side, + cached_build_hash_map_id, + metrics_cloned, + sender, + ) + }, + ) + }) + .try_flatten(), + )); + 
Ok(context.coalesce_with_default_batch_size(output_stream, &metrics)?) + } +} + +impl ExecuteWithColumnPruning for BroadcastJoinExec { + fn execute_projected( + &self, + partition: usize, + context: Arc, + projection: &[usize], + ) -> Result { + self.execute_with_projection(partition, context, projection.to_vec()) + } } impl ExecutionPlan for BroadcastJoinExec { @@ -111,7 +224,10 @@ impl ExecutionPlan for BroadcastJoinExec { } fn output_partitioning(&self) -> Partitioning { - self.right.output_partitioning() + match self.broadcast_side { + JoinSide::Left => self.right.output_partitioning(), + JoinSide::Right => self.left.output_partitioning(), + } } fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { @@ -127,11 +243,13 @@ impl ExecutionPlan for BroadcastJoinExec { children: Vec>, ) -> Result> { Ok(Arc::new(Self::try_new( + self.schema.clone(), children[0].clone(), children[1].clone(), self.on.iter().cloned().collect(), self.join_type, - self.join_filter.clone(), + self.broadcast_side, + None, )?)) } @@ -140,21 +258,8 @@ impl ExecutionPlan for BroadcastJoinExec { partition: usize, context: Arc, ) -> Result { - let stream = execute_broadcast_join( - self.left.clone(), - self.right.clone(), - partition, - context, - self.on.clone(), - self.join_type, - self.join_filter.clone(), - BaselineMetrics::new(&self.metrics, partition), - ); - - Ok(Box::pin(RecordBatchStreamAdapter::new( - self.schema(), - once(stream).try_flatten(), - ))) + let projection = (0..self.schema.fields().len()).collect(); + self.execute_with_projection(partition, context, projection) } fn metrics(&self) -> Option { @@ -172,221 +277,188 @@ impl DisplayAs for BroadcastJoinExec { } } -async fn execute_broadcast_join( - left: Arc, - right: Arc, - partition: usize, - context: Arc, - on: JoinOn, - join_type: JoinType, - join_filter: Option, - metrics: BaselineMetrics, -) -> Result { - let enabled_fallback_to_smj = conf::BHJ_FALLBACKS_TO_SMJ_ENABLE.value()?; - let bhj_num_rows_limit = conf::BHJ_FALLBACKS_TO_SMJ_ROWS_THRESHOLD.value()? as usize; - let bhj_mem_size_limit = conf::BHJ_FALLBACKS_TO_SMJ_MEM_THRESHOLD.value()? as usize; - - // if broadcasted size is small enough, use hash join - // otherwise use sort-merge join - #[derive(Debug)] - enum JoinMode { - Hash, - SortMerge, - } - let mut join_mode = JoinMode::Hash; - - let left_schema = left.schema(); - let mut left = left; - - if enabled_fallback_to_smj { - let mut left_stream = left.execute(0, context.clone())?.fuse(); - let mut left_cached: Vec = vec![]; - let mut left_num_rows = 0; - let mut left_mem_size = 0; - - // read and cache batches from broadcasted side until reached limits - while let Some(batch) = left_stream.next().await.transpose()? 
{ - left_num_rows += batch.num_rows(); - left_mem_size += batch.get_array_memory_size(); - left_cached.push(batch); - if left_num_rows > bhj_num_rows_limit || left_mem_size > bhj_mem_size_limit { - join_mode = JoinMode::SortMerge; - break; - } - } - - // convert left cached and rest batches into execution plan - let left_cached_stream: SendableRecordBatchStream = Box::pin(MemoryStream::try_new( - left_cached, - left_schema.clone(), - None, - )?); - let left_rest_stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new( - left_schema.clone(), - left_stream, - )); - let left_stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new( - left_schema.clone(), - left_cached_stream.chain(left_rest_stream), - )); - left = Arc::new(RecordBatchStreamsWrapperExec { - schema: left_schema.clone(), - stream: Mutex::new(Some(left_stream)), - output_partitioning: right.output_partitioning(), - }); - } - - match join_mode { - JoinMode::Hash => { - let join = Arc::new(HashJoinExec::try_new( - left.clone(), - right.clone(), - on, - join_filter, - &join_type, - PartitionMode::CollectLeft, - false, - )?); - log::info!("BroadcastJoin is using hash join mode: {:?}", &join); - - let join_schema = join.schema(); - let completed = join - .execute(partition, context)? - .chain(futures::stream::poll_fn(move |_| { - // update metrics - let join_metrics = join.metrics().unwrap(); - metrics.record_output(join_metrics.output_rows().unwrap_or(0)); - metrics.elapsed_compute().add_duration(Duration::from_nanos( - [ - join_metrics - .sum_by_name("build_time") - .map(|v| v.as_usize() as u64), - join_metrics - .sum_by_name("join_time") - .map(|v| v.as_usize() as u64), - ] - .into_iter() - .flatten() - .sum(), - )); - Poll::Ready(None) - })); - Ok(Box::pin(RecordBatchStreamAdapter::new( - join_schema, - completed, - ))) +async fn execute_join( + left: SendableRecordBatchStream, + right: SendableRecordBatchStream, + join_params: JoinParams, + broadcast_side: JoinSide, + cached_build_hash_map_id: Option, + metrics: Arc, + sender: Arc, +) -> Result<()> { + let start_time = Instant::now(); + let mut excluded_time_ns = 0; + let poll_time = Time::new(); + + let (mut probed, _keys, mut joiner): (_, _, Pin>) = match broadcast_side + { + JoinSide::Left => { + let right_schema = right.schema(); + let mut right_peeked = Box::pin(right.peekable()); + let (_, lmap_result) = futures::join!( + // fetch two sides asynchronously + async { + let timer = poll_time.timer(); + right_peeked.as_mut().peek().await; + drop(timer); + }, + collect_join_hash_map( + cached_build_hash_map_id, + left, + &join_params.left_keys, + poll_time.clone(), + ), + ); + let lmap = lmap_result?; + ( + Box::pin(RecordBatchStreamAdapter::new(right_schema, right_peeked)), + join_params.right_keys.clone(), + match join_params.join_type { + Inner => Box::pin(RProbedInnerJoiner::new(join_params, lmap, sender)), + Left => Box::pin(RProbedLeftJoiner::new(join_params, lmap, sender)), + Right => Box::pin(RProbedRightJoiner::new(join_params, lmap, sender)), + Full => Box::pin(RProbedFullOuterJoiner::new(join_params, lmap, sender)), + LeftSemi => Box::pin(RProbedLeftSemiJoiner::new(join_params, lmap, sender)), + LeftAnti => Box::pin(RProbedLeftAntiJoiner::new(join_params, lmap, sender)), + RightSemi => Box::pin(RProbedRightSemiJoiner::new(join_params, lmap, sender)), + RightAnti => Box::pin(RProbedRightAntiJoiner::new(join_params, lmap, sender)), + Existence => Box::pin(RProbedExistenceJoiner::new(join_params, lmap, sender)), + }, + ) } - 
JoinMode::SortMerge => { - let sort_exprs: Vec = on - .iter() - .map(|(_col_left, col_right)| PhysicalSortExpr { - expr: Arc::new(Column::new( - "", - downcast_any!(col_right, Column) - .expect("requires column") - .index(), - )), - options: Default::default(), - }) - .collect(); - - let right_sorted = Arc::new(SortExec::new(right, sort_exprs.clone(), None)); - let join = Arc::new(SortMergeJoinExec::try_new( - left.clone(), - right_sorted.clone(), - on, - join_type, - join_filter, - sort_exprs.into_iter().map(|se| se.options).collect(), - )?); - log::info!("BroadcastJoin is using sort-merge join mode: {:?}", &join); - - let join_schema = join.schema(); - let completed = join - .execute(partition, context)? - .chain(futures::stream::poll_fn(move |_| { - // update metrics - let right_sorted_metrics = right_sorted.metrics().unwrap(); - let join_metrics = join.metrics().unwrap(); - metrics.record_output(join_metrics.output_rows().unwrap_or(0)); - metrics.elapsed_compute().add_duration(Duration::from_nanos( - [ - right_sorted_metrics.elapsed_compute(), - join_metrics.elapsed_compute(), - ] - .into_iter() - .flatten() - .sum::() as u64, - )); - Poll::Ready(None) - })); - Ok(Box::pin(RecordBatchStreamAdapter::new( - join_schema, - completed, - ))) + JoinSide::Right => { + let left_schema = left.schema(); + let mut left_peeked = Box::pin(left.peekable()); + let (_, rmap_result) = futures::join!( + // fetch two sides asynchronizely + async { + let timer = poll_time.timer(); + left_peeked.as_mut().peek().await; + drop(timer); + }, + collect_join_hash_map( + cached_build_hash_map_id, + right, + &join_params.right_keys, + poll_time.clone(), + ), + ); + let rmap = rmap_result?; + ( + Box::pin(RecordBatchStreamAdapter::new(left_schema, left_peeked)), + join_params.left_keys.clone(), + match join_params.join_type { + Inner => Box::pin(LProbedInnerJoiner::new(join_params, rmap, sender)), + Left => Box::pin(LProbedLeftJoiner::new(join_params, rmap, sender)), + Right => Box::pin(LProbedRightJoiner::new(join_params, rmap, sender)), + Full => Box::pin(LProbedFullOuterJoiner::new(join_params, rmap, sender)), + LeftSemi => Box::pin(LProbedLeftSemiJoiner::new(join_params, rmap, sender)), + LeftAnti => Box::pin(LProbedLeftAntiJoiner::new(join_params, rmap, sender)), + RightSemi => Box::pin(LProbedRightSemiJoiner::new(join_params, rmap, sender)), + RightAnti => Box::pin(LProbedRightAntiJoiner::new(join_params, rmap, sender)), + Existence => Box::pin(LProbedExistenceJoiner::new(join_params, rmap, sender)), + }, + ) } + }; + + while let Some(batch) = { + let timer = poll_time.timer(); + let batch = probed.next().await.transpose()?; + drop(timer); + batch + } { + joiner.as_mut().join(batch).await?; } + joiner.as_mut().finish().await?; + metrics.record_output(joiner.num_output_rows()); + + excluded_time_ns += poll_time.value(); + excluded_time_ns += joiner.total_send_output_time(); + + // discount poll input and send output batch time + let mut join_time_ns = (Instant::now() - start_time).as_nanos() as u64; + join_time_ns -= excluded_time_ns as u64; + metrics + .elapsed_compute() + .add_duration(Duration::from_nanos(join_time_ns)); + Ok(()) } -pub struct RecordBatchStreamsWrapperExec { - pub schema: SchemaRef, - pub stream: Mutex>, - pub output_partitioning: Partitioning, +async fn collect_join_hash_map( + cached_build_hash_map_id: Option, + input: SendableRecordBatchStream, + key_exprs: &[PhysicalExprRef], + poll_time: Time, +) -> Result> { + Ok(match cached_build_hash_map_id { + Some(cached_id) => { + 
get_cached_join_hash_map(&cached_id, || async { + collect_join_hash_map_without_caching(input, key_exprs, poll_time).await + }) + .await? + } + None => { + let map = collect_join_hash_map_without_caching(input, key_exprs, poll_time).await?; + Arc::new(map) + } + }) } -impl Debug for RecordBatchStreamsWrapperExec { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "RecordBatchStreamsWrapper") +async fn collect_join_hash_map_without_caching( + mut input: SendableRecordBatchStream, + key_exprs: &[PhysicalExprRef], + poll_time: Time, +) -> Result { + let mut hash_map_batches = vec![]; + while let Some(batch) = { + let timer = poll_time.timer(); + let batch = input.next().await.transpose()?; + drop(timer); + batch + } { + hash_map_batches.push(batch); } -} - -impl DisplayAs for RecordBatchStreamsWrapperExec { - fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - write!(f, "RecordBatchStreamsWrapper") + match hash_map_batches.len() { + 0 => Ok(JoinHashMap::try_new_empty(input.schema(), key_exprs)?), + 1 => Ok(JoinHashMap::try_from_hash_map_batch( + hash_map_batches[0].clone(), + key_exprs, + )?), + n => df_execution_err!("expect zero or one hash map batch, got {n}"), } } -impl ExecutionPlan for RecordBatchStreamsWrapperExec { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - self.schema.clone() - } - - fn output_partitioning(&self) -> Partitioning { - self.output_partitioning.clone() - } - - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - None - } +#[async_trait] +pub trait Joiner { + async fn join(self: Pin<&mut Self>, probed_batch: RecordBatch) -> Result<()>; + async fn finish(self: Pin<&mut Self>) -> Result<()>; - fn children(&self) -> Vec> { - vec![] - } - - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - unimplemented!() - } - - fn execute( - &self, - _partition: usize, - _context: Arc, - ) -> Result { - let stream = std::mem::take(&mut *self.stream.lock()); - Ok(Box::pin(RecordBatchStreamAdapter::new( - self.schema.clone(), - Box::pin(futures::stream::iter(stream).flatten()), - ))) - } + fn total_send_output_time(&self) -> usize; + fn num_output_rows(&self) -> usize; +} - fn statistics(&self) -> Result { - unimplemented!() +async fn get_cached_join_hash_map> + Send>( + cached_id: &str, + init: impl FnOnce() -> Fut, +) -> Result> { + type Slot = Arc>>; + static CACHED_JOIN_HASH_MAP: OnceCell>>> = OnceCell::new(); + + // TODO: remove expired keys from cached join hash map + let cached_join_hash_map = CACHED_JOIN_HASH_MAP.get_or_init(|| Arc::default()); + let slot = cached_join_hash_map + .lock() + .entry(cached_id.to_string()) + .or_default() + .clone(); + + let mut slot = slot.lock().await; + if let Some(cached) = slot.upgrade() { + Ok(cached) + } else { + let new = Arc::new(init().await?); + *slot = Arc::downgrade(&new); + Ok(new) } } diff --git a/native-engine/datafusion-ext-plans/src/broadcast_nested_loop_join_exec.rs b/native-engine/datafusion-ext-plans/src/broadcast_nested_loop_join_exec.rs deleted file mode 100644 index b52e77f00..000000000 --- a/native-engine/datafusion-ext-plans/src/broadcast_nested_loop_join_exec.rs +++ /dev/null @@ -1,252 +0,0 @@ -// Copyright 2022 The Blaze Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::{any::Any, fmt::Formatter, sync::Arc}; - -use arrow::datatypes::SchemaRef; -use datafusion::{ - common::{JoinType, Result, Statistics}, - execution::{SendableRecordBatchStream, TaskContext}, - physical_expr::{Partitioning, PhysicalSortExpr}, - physical_plan::{ - joins::{ - utils::{build_join_schema, check_join_is_valid, JoinFilter}, - NestedLoopJoinExec, - }, - memory::MemoryExec, - metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, - stream::RecordBatchStreamAdapter, - DisplayAs, DisplayFormatType, ExecutionPlan, - }, -}; -use datafusion_ext_commons::batch_size; -use futures::{stream::once, StreamExt, TryStreamExt}; -use parking_lot::Mutex; - -use crate::broadcast_join_exec::RecordBatchStreamsWrapperExec; - -#[derive(Debug)] -pub struct BroadcastNestedLoopJoinExec { - left: Arc, - right: Arc, - join_type: JoinType, - filter: Option, - schema: SchemaRef, - metrics: ExecutionPlanMetricsSet, -} - -impl BroadcastNestedLoopJoinExec { - pub fn try_new( - left: Arc, - right: Arc, - join_type: JoinType, - filter: Option, - ) -> Result { - let left_schema = left.schema(); - let right_schema = right.schema(); - check_join_is_valid(&left_schema, &right_schema, &[])?; - let (schema, _column_indices) = build_join_schema(&left_schema, &right_schema, &join_type); - - Ok(Self { - left, - right, - filter, - join_type, - schema: Arc::new(schema), - metrics: ExecutionPlanMetricsSet::new(), - }) - } -} - -impl DisplayAs for BroadcastNestedLoopJoinExec { - fn fmt_as(&self, _: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - write!(f, "BroadcastNestedLoopJoin") - } -} - -impl ExecutionPlan for BroadcastNestedLoopJoinExec { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - self.schema.clone() - } - - fn output_partitioning(&self) -> Partitioning { - if left_is_build_side(self.join_type) { - self.right.output_partitioning() - } else { - self.left.output_partitioning() - } - } - - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - None - } - - fn children(&self) -> Vec> { - vec![self.left.clone(), self.right.clone()] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - Ok(Arc::new(Self::try_new( - children[0].clone(), - children[1].clone(), - self.join_type, - self.filter.clone(), - )?)) - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - let joined = Box::pin(RecordBatchStreamAdapter::new( - self.schema(), - once(execute_join( - partition, - context, - self.left.clone(), - self.right.clone(), - self.join_type, - self.filter.clone(), - self.metrics.clone(), - )) - .try_flatten(), - )); - Ok(joined) - } - - fn metrics(&self) -> Option { - Some(self.metrics.clone_inner()) - } - - fn statistics(&self) -> Result { - todo!() - } -} - -async fn execute_join( - partition: usize, - context: Arc, - left: Arc, - right: Arc, - join_type: JoinType, - filter: Option, - metrics: ExecutionPlanMetricsSet, -) -> Result { - // inner side - let mut inner_stream = if left_is_build_side(join_type) { - left.execute(partition, context.clone())? 
- } else { - right.execute(partition, context.clone())? - }; - let inner_schema = inner_stream.schema(); - let mut inner_batches = vec![]; - while let Some(batch) = inner_stream.next().await.transpose()? { - inner_batches.push(batch); - } - - let inner_batch_max_num_rows = inner_batches - .iter() - .map(|batch| batch.num_rows()) - .max() - .unwrap_or(0); - let inner_batch_max_mem_size = inner_batches - .iter() - .map(|batch| batch.get_array_memory_size()) - .max() - .unwrap_or(0); - - let target_output_num_rows = batch_size(); - let target_output_mem_size = 1 << 26; // 64MB - let inner_exec: Arc = - Arc::new(MemoryExec::try_new(&[inner_batches], inner_schema, None)?); - - // outer side - let (outer_schema, outer_partitioning, outer_stream) = if left_is_build_side(join_type) { - ( - right.schema(), - right.output_partitioning(), - right.execute(partition, context.clone())?, - ) - } else { - ( - left.schema(), - left.output_partitioning(), - left.execute(partition, context.clone())?, - ) - }; - let chunked_outer_stream = Box::pin(RecordBatchStreamAdapter::new( - outer_schema.clone(), - outer_stream.flat_map(move |batch_result| match batch_result { - Ok(batch) => { - let batch_num_rows = batch.num_rows(); - let batch_mem_size = batch.get_array_memory_size(); - let output_num_rows = batch_num_rows * inner_batch_max_num_rows; - let output_mem_size = batch_num_rows * inner_batch_max_mem_size - + batch_mem_size * inner_batch_max_num_rows; - let chunk_count = std::cmp::min( - (output_num_rows / target_output_num_rows).max(1), - (output_mem_size / target_output_mem_size).max(1), - ); - let chunk_len = (batch_num_rows / chunk_count).max(1); - - let mut chunks = vec![]; - for beg in (0..batch.num_rows()).step_by(chunk_len) { - chunks.push(Ok(batch.slice(beg, chunk_len.min(batch.num_rows() - beg)))); - } - futures::stream::iter(chunks) - } - Err(err) => futures::stream::iter(vec![Err(err)]), - }), - )); - let outer_exec: Arc = Arc::new(RecordBatchStreamsWrapperExec { - schema: outer_schema, - stream: Mutex::new(Some(chunked_outer_stream)), - output_partitioning: outer_partitioning, - }); - - // join with datafusion's builtin NestedLoopJoinExec - let nlj = if left_is_build_side(join_type) { - NestedLoopJoinExec::try_new(inner_exec, outer_exec, filter, &join_type)? - } else { - NestedLoopJoinExec::try_new(outer_exec, inner_exec, filter, &join_type)? 
- }; - let joined = nlj.execute(partition, context)?; - - let baseline_metrics = BaselineMetrics::new(&metrics, partition); - let output_stream = Box::pin(RecordBatchStreamAdapter::new( - joined.schema(), - joined.map(move |batch_result| { - if let Ok(batch) = &batch_result { - baseline_metrics.record_output(batch.num_rows()); - } - batch_result - }), - )); - Ok(output_stream) -} - -fn left_is_build_side(join_type: JoinType) -> bool { - matches!( - join_type, - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti | JoinType::Full - ) -} diff --git a/native-engine/datafusion-ext-plans/src/common/batch_selection.rs b/native-engine/datafusion-ext-plans/src/common/batch_selection.rs index 6aa839551..a5e789cd5 100644 --- a/native-engine/datafusion-ext-plans/src/common/batch_selection.rs +++ b/native-engine/datafusion-ext-plans/src/common/batch_selection.rs @@ -41,16 +41,33 @@ pub fn take_batch_opt( take_batch_internal(batch, indices) } +pub fn take_cols( + cols: &[ArrayRef], + indices: impl IntoIterator, +) -> Result> { + let indices: UInt32Array = + PrimitiveArray::from_iter(indices.into_iter().map(|idx| idx.to_u32().unwrap())); + take_cols_internal(cols, &indices) +} + +pub fn take_cols_opt( + cols: &[ArrayRef], + indices: impl IntoIterator>, +) -> Result> { + let indices: UInt32Array = PrimitiveArray::from_iter( + indices + .into_iter() + .map(|opt| opt.map(|idx| idx.to_u32().unwrap())), + ); + take_cols_internal(cols, &indices) +} + fn take_batch_internal(batch: RecordBatch, indices: UInt32Array) -> Result { let taken_num_batch_rows = indices.len(); let schema = batch.schema(); - let cols = batch.columns().to_vec(); - drop(batch); // we would like to release batch as soon as possible + let cols = batch.columns(); - let cols = cols - .into_iter() - .map(|c| Ok(arrow::compute::take(&c, &indices, None)?)) - .collect::>()?; + let cols = take_cols_internal(cols, &indices)?; drop(indices); let taken = RecordBatch::try_new_with_options( @@ -61,6 +78,14 @@ fn take_batch_internal(batch: RecordBatch, indices: UInt32Array) -> Result Result> { + let cols = cols + .into_iter() + .map(|c| Ok(arrow::compute::take(&c, indices, None)?)) + .collect::>()?; + Ok(cols) +} + pub fn interleave_batches( schema: SchemaRef, batches: &[RecordBatch], diff --git a/native-engine/datafusion-ext-plans/src/common/output.rs b/native-engine/datafusion-ext-plans/src/common/output.rs index b1a0a2807..d888026ef 100644 --- a/native-engine/datafusion-ext-plans/src/common/output.rs +++ b/native-engine/datafusion-ext-plans/src/common/output.rs @@ -20,6 +20,7 @@ use std::{ }; use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use async_trait::async_trait; use blaze_jni_bridge::is_task_running; use datafusion::{ common::Result, @@ -221,3 +222,34 @@ impl TaskOutputter for Arc { WrappedRecordBatchSender::cancel_task(self); } } + +#[async_trait] +pub trait NextBatchWithTimer { + async fn next_batch( + &mut self, + stop_timer: Option<&mut ScopedTimerGuard<'_>>, + ) -> Result>; +} + +#[async_trait] +impl NextBatchWithTimer for SendableRecordBatchStream { + async fn next_batch( + &mut self, + stop_timer: Option<&mut ScopedTimerGuard<'_>>, + ) -> Result> { + struct StopScopedTimerGuard<'a, 'z>(&'a mut ScopedTimerGuard<'z>); + impl<'a, 'z> StopScopedTimerGuard<'a, 'z> { + fn new(timer: &'a mut ScopedTimerGuard<'z>) -> Self { + timer.stop(); + Self(timer) + } + } + impl Drop for StopScopedTimerGuard<'_, '_> { + fn drop(&mut self) { + self.0.restart(); + } + } + let _stop_timer = stop_timer.map(|timer| 
StopScopedTimerGuard::new(timer)); + self.next().await.transpose() + } +} diff --git a/native-engine/datafusion-ext-plans/src/ipc_reader_exec.rs b/native-engine/datafusion-ext-plans/src/ipc_reader_exec.rs index fb8651ce1..9688e5af2 100644 --- a/native-engine/datafusion-ext-plans/src/ipc_reader_exec.rs +++ b/native-engine/datafusion-ext-plans/src/ipc_reader_exec.rs @@ -192,8 +192,8 @@ pub async fn read_ipc( })); while let Some(batch) = { - let reader_cloned = reader.clone(); - tokio::task::spawn_blocking(move || reader_cloned.clone().lock().read_batch()) + let reader = reader.clone(); + tokio::task::spawn_blocking(move || reader.lock().read_batch()) .await .or_else(|err| df_execution_err!("{err}"))?? } { diff --git a/native-engine/datafusion-ext-plans/src/joins/bhj/full_join.rs b/native-engine/datafusion-ext-plans/src/joins/bhj/full_join.rs new file mode 100644 index 000000000..ca51b5629 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/bhj/full_join.rs @@ -0,0 +1,324 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering::Relaxed}, + Arc, + }, +}; + +use arrow::array::{new_null_array, ArrayRef, RecordBatch}; +use async_trait::async_trait; +use bitvec::{bitvec, prelude::BitVec}; +use datafusion::{common::Result, physical_plan::metrics::Time}; + +use crate::{ + broadcast_join_exec::Joiner, + common::{batch_selection::take_cols, output::WrappedRecordBatchSender}, + joins::{ + bhj::{ + filter_joined_indices, + full_join::ProbeSide::{L, R}, + ProbeSide, + }, + join_hash_map::{join_create_hashes, JoinHashMap}, + JoinParams, + }, +}; + +#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)] +pub struct JoinerParams { + probe_side: ProbeSide, + probe_side_outer: bool, + build_side_outer: bool, +} + +impl JoinerParams { + const fn new(probe_side: ProbeSide, probe_side_outer: bool, build_side_outer: bool) -> Self { + Self { + probe_side, + probe_side_outer, + build_side_outer, + } + } +} + +const LEFT_PROBED_INNER: JoinerParams = JoinerParams::new(L, false, false); +const LEFT_PROBED_LEFT: JoinerParams = JoinerParams::new(L, true, false); +const LEFT_PROBED_RIGHT: JoinerParams = JoinerParams::new(L, false, true); +const LEFT_PROBED_OUTER: JoinerParams = JoinerParams::new(L, true, true); + +const RIGHT_PROBED_INNER: JoinerParams = JoinerParams::new(R, false, false); +const RIGHT_PROBED_LEFT: JoinerParams = JoinerParams::new(R, false, true); +const RIGHT_PROBED_RIGHT: JoinerParams = JoinerParams::new(R, true, false); +const RIGHT_PROBED_OUTER: JoinerParams = JoinerParams::new(R, true, true); + +pub type LProbedInnerJoiner = FullJoiner; +pub type LProbedLeftJoiner = FullJoiner; +pub type LProbedRightJoiner = FullJoiner; +pub type LProbedFullOuterJoiner = FullJoiner; +pub type RProbedInnerJoiner = FullJoiner; +pub type RProbedLeftJoiner = FullJoiner; +pub type RProbedRightJoiner = FullJoiner; +pub type RProbedFullOuterJoiner = FullJoiner; + +pub struct FullJoiner { + 
join_params: JoinParams, + output_sender: Arc<WrappedRecordBatchSender>, + map: Arc<JoinHashMap>, + map_joined: BitVec, + send_output_time: Time, + output_rows: AtomicUsize, +} + +impl<const P: JoinerParams> FullJoiner<P>
{ + pub fn new( + join_params: JoinParams, + map: Arc, + output_sender: Arc, + ) -> Self { + let map_joined = bitvec![0; map.data_batch().num_rows()]; + Self { + join_params, + output_sender, + map, + map_joined, + send_output_time: Time::default(), + output_rows: AtomicUsize::new(0), + } + } + + fn create_probed_key_columns(&self, probed_batch: &RecordBatch) -> Result> { + let probed_key_exprs = match P.probe_side { + L => &self.join_params.left_keys, + R => &self.join_params.right_keys, + }; + let probed_key_columns: Vec = probed_key_exprs + .iter() + .map(|expr| { + Ok(expr + .evaluate(probed_batch)? + .into_array(probed_batch.num_rows())?) + }) + .collect::>()?; + Ok(probed_key_columns) + } + + async fn flush(&self, probe_cols: Vec, build_cols: Vec) -> Result<()> { + let output_batch = RecordBatch::try_new( + self.join_params.output_schema.clone(), + match P.probe_side { + L => [probe_cols, build_cols].concat(), + R => [build_cols, probe_cols].concat(), + }, + )?; + self.output_rows.fetch_add(output_batch.num_rows(), Relaxed); + + let timer = self.send_output_time.timer(); + self.output_sender.send(Ok(output_batch), None).await; + drop(timer); + Ok(()) + } + + async fn flush_hash_joined( + mut self: Pin<&mut Self>, + probed_batch: &RecordBatch, + probed_key_columns: &[ArrayRef], + probed_joined: &mut BitVec, + mut hash_joined_probe_indices: Vec, + mut hash_joined_build_indices: Vec, + ) -> Result<()> { + filter_joined_indices( + probed_key_columns, + self.map.key_columns(), + &mut hash_joined_probe_indices, + &mut hash_joined_build_indices, + )?; + let probe_indices = hash_joined_probe_indices; + let build_indices = hash_joined_build_indices; + + let pprojected = match P.probe_side { + L => self + .join_params + .projection + .project_left(probed_batch.columns()), + R => self + .join_params + .projection + .project_right(probed_batch.columns()), + }; + let mprojected = match P.probe_side { + L => self + .join_params + .projection + .project_right(self.map.data_batch().columns()), + R => self + .join_params + .projection + .project_left(self.map.data_batch().columns()), + }; + for &idx in &probe_indices { + probed_joined.set(idx as usize, true); + } + let pcols = if probe_indices.len() == probed_batch.num_rows() && probed_joined.all() { + // fast path for the case where every probed records have 1-to-1 joined + pprojected + } else { + take_cols(&pprojected, probe_indices)? + }; + + for &idx in &build_indices { + self.map_joined.set(idx as usize, true); + } + let bcols = take_cols(&mprojected, build_indices)?; + + self.flush(pcols, bcols).await?; + Ok(()) + } +} + +#[async_trait] +impl Joiner for FullJoiner
{ + async fn join(mut self: Pin<&mut Self>, probed_batch: RecordBatch) -> Result<()> { + let mut hash_joined_probe_indices: Vec = vec![]; + let mut hash_joined_build_indices: Vec = vec![]; + let mut probed_joined = bitvec![0; probed_batch.num_rows()]; + let batch_size = self.join_params.batch_size.max(probed_batch.num_rows()); + + let probed_key_columns = self.create_probed_key_columns(&probed_batch)?; + let probed_hashes = join_create_hashes(probed_batch.num_rows(), &probed_key_columns)?; + + // join by hash code + for (row_idx, &hash) in probed_hashes.iter().enumerate() { + let mut maybe_joined = false; + if let Some(entries) = self.map.entry_indices(hash) { + for map_idx in entries { + hash_joined_probe_indices.push(row_idx as u32); + hash_joined_build_indices.push(map_idx); + } + maybe_joined = true; + } + + if maybe_joined && hash_joined_probe_indices.len() > batch_size { + self.as_mut() + .flush_hash_joined( + &probed_batch, + &probed_key_columns, + &mut probed_joined, + std::mem::take(&mut hash_joined_probe_indices), + std::mem::take(&mut hash_joined_build_indices), + ) + .await?; + } + } + if !hash_joined_probe_indices.is_empty() { + self.as_mut() + .flush_hash_joined( + &probed_batch, + &probed_key_columns, + &mut probed_joined, + hash_joined_probe_indices, + hash_joined_build_indices, + ) + .await?; + } + + // output unjoined rows of probed side + if P.probe_side_outer { + let probed_unjoined_indices = probed_joined + .iter() + .enumerate() + .filter(|(_, joined)| !**joined) + .map(|(idx, _)| idx as u32) + .collect::>(); + + let pprojected = match P.probe_side { + L => self + .join_params + .projection + .project_left(probed_batch.columns()), + R => self + .join_params + .projection + .project_right(probed_batch.columns()), + }; + let mprojected = match P.probe_side { + L => self + .join_params + .projection + .project_right(self.map.data_batch().columns()), + R => self + .join_params + .projection + .project_left(self.map.data_batch().columns()), + }; + + let bcols = mprojected + .iter() + .map(|col| new_null_array(col.data_type(), probed_unjoined_indices.len())) + .collect::>(); + + let pcols = take_cols(&pprojected, probed_unjoined_indices)?; + self.as_mut().flush(pcols, bcols).await?; + } + Ok(()) + } + + async fn finish(mut self: Pin<&mut Self>) -> Result<()> { + // output unjoined rows of probed side + let map_joined = std::mem::take(&mut self.map_joined); + if P.build_side_outer { + let map_unjoined_indices = map_joined + .into_iter() + .enumerate() + .filter(|(_, joined)| !joined) + .map(|(idx, _)| idx as u32) + .collect::>(); + + let pschema = match P.probe_side { + L => &self.join_params.left_schema, + R => &self.join_params.right_schema, + }; + let mprojected = match P.probe_side { + L => self + .join_params + .projection + .project_right(self.map.data_batch().columns()), + R => self + .join_params + .projection + .project_left(self.map.data_batch().columns()), + }; + + let pcols = pschema + .fields() + .iter() + .map(|field| new_null_array(field.data_type(), map_unjoined_indices.len())) + .collect::>(); + let bcols = take_cols(&mprojected, map_unjoined_indices)?; + self.as_mut().flush(pcols, bcols).await?; + } + Ok(()) + } + + fn total_send_output_time(&self) -> usize { + self.send_output_time.value() + } + + fn num_output_rows(&self) -> usize { + self.output_rows.load(Relaxed) + } +} diff --git a/native-engine/datafusion-ext-plans/src/joins/bhj/mod.rs b/native-engine/datafusion-ext-plans/src/joins/bhj/mod.rs new file mode 100644 index 000000000..57d934cd5 --- 
/dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/bhj/mod.rs @@ -0,0 +1,146 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use arrow::{ + array::*, + datatypes::{DataType, IntervalUnit, TimeUnit}, +}; +use datafusion::common::Result; +use datafusion_ext_commons::{df_execution_err, downcast_any}; + +pub mod full_join; +pub mod semi_join; + +#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)] +pub enum ProbeSide { + L, + R, +} + +fn filter_joined_indices( + key_columns1: &[ArrayRef], + key_columns2: &[ArrayRef], + indices1: &mut Vec, + indices2: &mut Vec, +) -> Result<()> { + fn filter_one( + key_column1: &ArrayRef, + key_column2: &ArrayRef, + indices1: &mut Vec, + indices2: &mut Vec, + ) -> Result<()> { + macro_rules! filter_atomic { + ($cast_type:ty) => {{ + let col1 = downcast_any!(key_column1, $cast_type)?; + let col2 = downcast_any!(key_column2, $cast_type)?; + let mut valid_count = 0; + for i in 0..indices1.len() { + let idx1 = indices1[i] as usize; + let idx2 = indices2[i] as usize; + if col1.is_valid(idx1) && col2.is_valid(idx2) && { + let v1 = col1.value(idx1); + let v2 = col2.value(idx2); + v1 == v2 + } { + indices1[valid_count] = indices1[i]; + indices2[valid_count] = indices2[i]; + valid_count += 1; + } + } + indices1.truncate(valid_count); + indices2.truncate(valid_count); + }}; + } + + let dt1 = key_column1.data_type(); + let dt2 = key_column2.data_type(); + if dt1 != dt2 { + return df_execution_err!("join key data type not matched: {dt1:?} <-> {dt2:?}"); + } + match dt1 { + DataType::Null => { + indices1.clear(); + indices2.clear(); + } + DataType::Boolean => filter_atomic!(BooleanArray), + DataType::Int8 => filter_atomic!(Int8Array), + DataType::Int16 => filter_atomic!(Int16Array), + DataType::Int32 => filter_atomic!(Int32Array), + DataType::Int64 => filter_atomic!(Int64Array), + DataType::UInt8 => filter_atomic!(UInt8Array), + DataType::UInt16 => filter_atomic!(UInt16Array), + DataType::UInt32 => filter_atomic!(UInt32Array), + DataType::UInt64 => filter_atomic!(UInt64Array), + DataType::Float16 => filter_atomic!(Float16Array), + DataType::Float32 => filter_atomic!(Float32Array), + DataType::Float64 => filter_atomic!(Float64Array), + DataType::Timestamp(unit, _) => match unit { + TimeUnit::Second => filter_atomic!(TimestampSecondArray), + TimeUnit::Millisecond => filter_atomic!(TimestampMillisecondArray), + TimeUnit::Microsecond => filter_atomic!(TimestampMicrosecondArray), + TimeUnit::Nanosecond => filter_atomic!(TimestampNanosecondArray), + }, + DataType::Date32 => filter_atomic!(Date32Array), + DataType::Date64 => filter_atomic!(Date64Array), + DataType::Time32(unit) => match unit { + TimeUnit::Second => filter_atomic!(Time32SecondArray), + TimeUnit::Millisecond => filter_atomic!(Time32MillisecondArray), + TimeUnit::Microsecond => filter_atomic!(Time32MillisecondArray), + TimeUnit::Nanosecond => filter_atomic!(Time32MillisecondArray), + }, + DataType::Time64(unit) => match unit { + 
TimeUnit::Microsecond => filter_atomic!(Time64MicrosecondArray), + TimeUnit::Nanosecond => filter_atomic!(Time64NanosecondArray), + _ => return df_execution_err!("unsupported time64 unit: {unit:?}"), + }, + DataType::Duration(unit) => match unit { + TimeUnit::Second => filter_atomic!(DurationSecondArray), + TimeUnit::Millisecond => filter_atomic!(DurationMillisecondArray), + TimeUnit::Microsecond => filter_atomic!(DurationMicrosecondArray), + TimeUnit::Nanosecond => filter_atomic!(DurationNanosecondArray), + }, + DataType::Interval(unit) => match unit { + IntervalUnit::YearMonth => filter_atomic!(IntervalYearMonthArray), + IntervalUnit::DayTime => filter_atomic!(IntervalDayTimeArray), + IntervalUnit::MonthDayNano => filter_atomic!(IntervalMonthDayNanoArray), + }, + DataType::Binary => filter_atomic!(BinaryArray), + DataType::FixedSizeBinary(_) => filter_atomic!(FixedSizeBinaryArray), + DataType::LargeBinary => filter_atomic!(LargeBinaryArray), + DataType::Utf8 => filter_atomic!(StringArray), + DataType::LargeUtf8 => filter_atomic!(LargeStringArray), + DataType::List(_) => filter_atomic!(ListArray), + DataType::FixedSizeList(..) => filter_atomic!(FixedSizeListArray), + DataType::LargeList(_) => filter_atomic!(LargeListArray), + DataType::Struct(_) => filter_joined_indices( + key_column1.as_struct().columns(), + key_column2.as_struct().columns(), + indices1, + indices2, + )?, + DataType::Decimal128(..) => filter_atomic!(Decimal128Array), + DataType::Decimal256(..) => filter_atomic!(Decimal256Array), + DataType::Map(..) => filter_atomic!(MapArray), + dt => { + return df_execution_err!("unsupported data type: {dt:?}"); + } + } + Ok(()) + } + + for (key_column1, key_column2) in key_columns1.iter().zip(key_columns2) { + filter_one(key_column1, key_column2, indices1, indices2)?; + } + Ok(()) +} diff --git a/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs b/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs new file mode 100644 index 000000000..8c168f00c --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs @@ -0,0 +1,283 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
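For readers skimming the diff: filter_joined_indices above post-filters the candidate index pairs produced by hash lookup, because equal hashes do not guarantee equal keys. A minimal standalone sketch of that in-place compaction, with plain Option<i32> keys standing in for Arrow arrays (names here are illustrative, not the crate's API), looks roughly like this:

fn filter_joined_indices(
    keys1: &[Option<i32>],
    keys2: &[Option<i32>],
    indices1: &mut Vec<u32>,
    indices2: &mut Vec<u32>,
) {
    let mut valid_count = 0;
    for i in 0..indices1.len() {
        let v1 = keys1[indices1[i] as usize];
        let v2 = keys2[indices2[i] as usize];
        // nulls never match, mirroring the is_valid() checks on both sides
        if v1.is_some() && v1 == v2 {
            indices1[valid_count] = indices1[i];
            indices2[valid_count] = indices2[i];
            valid_count += 1;
        }
    }
    indices1.truncate(valid_count);
    indices2.truncate(valid_count);
}

fn main() {
    let keys1 = vec![Some(1), Some(2), None];
    let keys2 = vec![Some(2), Some(9), None];
    let mut idx1: Vec<u32> = vec![0, 1, 2];
    let mut idx2: Vec<u32> = vec![0, 0, 2];
    filter_joined_indices(&keys1, &keys2, &mut idx1, &mut idx2);
    // only the pair (1, 0) survives: keys1[1] == keys2[0] == 2
    assert_eq!((idx1, idx2), (vec![1], vec![0]));
}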
+ +use std::{ + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering::Relaxed}, + Arc, + }, +}; + +use arrow::array::{ArrayRef, BooleanArray, RecordBatch}; +use async_trait::async_trait; +use bitvec::{bitvec, prelude::BitVec}; +use datafusion::{common::Result, physical_plan::metrics::Time}; + +use crate::{ + broadcast_join_exec::Joiner, + common::{batch_selection::take_cols, output::WrappedRecordBatchSender}, + joins::{ + bhj::{ + filter_joined_indices, + semi_join::{ + ProbeSide::{L, R}, + SemiMode::{Anti, Existence, Semi}, + }, + ProbeSide, + }, + join_hash_map::{join_create_hashes, JoinHashMap}, + JoinParams, + }, +}; + +#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)] +pub enum SemiMode { + Semi, + Anti, + Existence, +} + +#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)] +pub struct JoinerParams { + probe_side: ProbeSide, + probe_is_join_side: bool, + mode: SemiMode, +} + +impl JoinerParams { + const fn new(probe_side: ProbeSide, probe_is_join_side: bool, mode: SemiMode) -> Self { + Self { + probe_side, + probe_is_join_side, + mode, + } + } +} + +const LEFT_PROBED_LEFT_SEMI: JoinerParams = JoinerParams::new(L, true, Semi); +const LEFT_PROBED_LEFT_ANTI: JoinerParams = JoinerParams::new(L, true, Anti); +const LEFT_PROBED_RIGHT_SEMI: JoinerParams = JoinerParams::new(L, false, Semi); +const LEFT_PROBED_RIGHT_ANTI: JoinerParams = JoinerParams::new(L, false, Anti); +const LEFT_PROBED_EXISTENCE: JoinerParams = JoinerParams::new(L, true, Existence); +const RIGHT_PROBED_LEFT_SEMI: JoinerParams = JoinerParams::new(R, false, Semi); +const RIGHT_PROBED_LEFT_ANTI: JoinerParams = JoinerParams::new(R, false, Anti); +const RIGHT_PROBED_RIGHT_SEMI: JoinerParams = JoinerParams::new(R, true, Semi); +const RIGHT_PROBED_RIGHT_ANTI: JoinerParams = JoinerParams::new(R, true, Anti); +const RIGHT_PROBED_EXISTENCE: JoinerParams = JoinerParams::new(R, false, Existence); + +pub type LProbedLeftSemiJoiner = SemiJoiner; +pub type LProbedLeftAntiJoiner = SemiJoiner; +pub type LProbedRightSemiJoiner = SemiJoiner; +pub type LProbedRightAntiJoiner = SemiJoiner; +pub type LProbedExistenceJoiner = SemiJoiner; +pub type RProbedLeftSemiJoiner = SemiJoiner; +pub type RProbedLeftAntiJoiner = SemiJoiner; +pub type RProbedRightSemiJoiner = SemiJoiner; +pub type RProbedRightAntiJoiner = SemiJoiner; +pub type RProbedExistenceJoiner = SemiJoiner; + +pub struct SemiJoiner { + join_params: JoinParams, + output_sender: Arc, + map_joined: BitVec, + map: Arc, + send_output_time: Time, + output_rows: AtomicUsize, +} + +impl SemiJoiner

{ + pub fn new( + join_params: JoinParams, + map: Arc, + output_sender: Arc, + ) -> Self { + let map_joined = bitvec![0; map.data_batch().num_rows()]; + Self { + join_params, + output_sender, + map, + map_joined, + send_output_time: Time::new(), + output_rows: AtomicUsize::new(0), + } + } + + fn create_probed_key_columns(&self, probed_batch: &RecordBatch) -> Result> { + let probed_key_exprs = match P.probe_side { + L => &self.join_params.left_keys, + R => &self.join_params.right_keys, + }; + let probed_key_columns: Vec = probed_key_exprs + .iter() + .map(|expr| { + Ok(expr + .evaluate(probed_batch)? + .into_array(probed_batch.num_rows())?) + }) + .collect::>()?; + Ok(probed_key_columns) + } + + async fn flush(&self, cols: Vec) -> Result<()> { + let output_batch = RecordBatch::try_new(self.join_params.output_schema.clone(), cols)?; + self.output_rows.fetch_add(output_batch.num_rows(), Relaxed); + + let timer = self.send_output_time.timer(); + self.output_sender.send(Ok(output_batch), None).await; + drop(timer); + Ok(()) + } + + fn flush_hash_joined( + mut self: Pin<&mut Self>, + probed_key_columns: &[ArrayRef], + probed_joined: &mut BitVec, + mut hash_joined_probe_indices: Vec, + mut hash_joined_build_indices: Vec, + ) -> Result<()> { + filter_joined_indices( + probed_key_columns, + self.map.key_columns(), + &mut hash_joined_probe_indices, + &mut hash_joined_build_indices, + )?; + let probe_indices = hash_joined_probe_indices; + let build_indices = hash_joined_build_indices; + + for &idx in &probe_indices { + probed_joined.set(idx as usize, true); + } + for &idx in &build_indices { + self.map_joined.set(idx as usize, true); + } + Ok(()) + } +} + +#[async_trait] +impl Joiner for SemiJoiner

{ + async fn join(mut self: Pin<&mut Self>, probed_batch: RecordBatch) -> Result<()> { + let mut hash_joined_probe_indices: Vec = vec![]; + let mut hash_joined_build_indices: Vec = vec![]; + let mut probed_joined = bitvec![0; probed_batch.num_rows()]; + + let probed_key_columns = self.create_probed_key_columns(&probed_batch)?; + let probed_hashes = join_create_hashes(probed_batch.num_rows(), &probed_key_columns)?; + + // join by hash code + for (row_idx, &hash) in probed_hashes.iter().enumerate() { + let mut maybe_joined = false; + if let Some(entries) = self.map.entry_indices(hash) { + for map_idx in entries { + hash_joined_probe_indices.push(row_idx as u32); + hash_joined_build_indices.push(map_idx); + } + maybe_joined = true; + } + + if maybe_joined && hash_joined_probe_indices.len() >= self.join_params.batch_size { + self.as_mut().flush_hash_joined( + &probed_key_columns, + &mut probed_joined, + std::mem::take(&mut hash_joined_probe_indices), + std::mem::take(&mut hash_joined_build_indices), + )?; + } + } + if !hash_joined_probe_indices.is_empty() { + self.as_mut().flush_hash_joined( + &probed_key_columns, + &mut probed_joined, + hash_joined_probe_indices, + hash_joined_build_indices, + )?; + } + + if P.probe_is_join_side { + let pprojected = match P.probe_side { + L => self + .join_params + .projection + .project_left(probed_batch.columns()), + R => self + .join_params + .projection + .project_right(probed_batch.columns()), + }; + let pcols = match P.mode { + Semi | Anti => { + let probed_indices = probed_joined + .into_iter() + .enumerate() + .filter(|(_, joined)| (P.mode == Semi) ^ !joined) + .map(|(idx, _)| idx as u32) + .collect::>(); + take_cols(&pprojected, probed_indices)? + } + Existence => { + let exists_col = Arc::new(BooleanArray::from( + probed_joined.into_iter().collect::>(), + )); + [pprojected, vec![exists_col]].concat() + } + }; + self.as_mut().flush(pcols).await?; + } + Ok(()) + } + + async fn finish(mut self: Pin<&mut Self>) -> Result<()> { + if !P.probe_is_join_side { + let mprojected = match P.probe_side { + L => self + .join_params + .projection + .project_right(self.map.data_batch().columns()), + R => self + .join_params + .projection + .project_left(self.map.data_batch().columns()), + }; + let map_joined = std::mem::take(&mut self.map_joined); + let pcols = match P.mode { + Semi | Anti => { + let map_indices = map_joined + .into_iter() + .enumerate() + .filter(|(_, joined)| (P.mode == Semi) ^ !joined) + .map(|(idx, _)| idx as u32) + .collect::>(); + take_cols(&mprojected, map_indices)? + } + Existence => { + let exists_col = Arc::new(BooleanArray::from( + map_joined.into_iter().collect::>(), + )); + [mprojected, vec![exists_col]].concat() + } + }; + self.as_mut().flush(pcols).await?; + } + Ok(()) + } + + fn total_send_output_time(&self) -> usize { + self.send_output_time.value() + } + + fn num_output_rows(&self) -> usize { + self.output_rows.load(Relaxed) + } +} diff --git a/native-engine/datafusion-ext-plans/src/joins/join_hash_map.rs b/native-engine/datafusion-ext-plans/src/joins/join_hash_map.rs new file mode 100644 index 000000000..8bd1a5731 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/join_hash_map.rs @@ -0,0 +1,340 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + io::{Cursor, Read, Write}, + slice::{from_raw_parts, from_raw_parts_mut}, + sync::Arc, +}; + +use arrow::{ + array::{ArrayRef, AsArray, BinaryBuilder, RecordBatch}, + datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}, +}; +use byteorder::{NativeEndian, ReadBytesExt, WriteBytesExt}; +use datafusion::{common::Result, physical_expr::PhysicalExprRef}; +use datafusion_ext_commons::spark_hash::create_hashes; +use hashbrown::HashMap; +use itertools::Itertools; +use once_cell::sync::OnceCell; + +use crate::common::batch_selection::take_batch; + +pub struct Table { + entry_offsets: Vec, + entry_lens: Vec, + item_indices: Vec, + item_hashes: Vec, +} + +impl Table { + pub fn new_empty() -> Self { + let num_entries = Self::num_entries_of_rows(0); + Self { + entry_offsets: vec![0; num_entries], + entry_lens: vec![0; num_entries], + item_indices: vec![], + item_hashes: vec![], + } + } + + pub fn try_from_key_columns( + num_rows: usize, + data_batch: RecordBatch, + key_columns: &[ArrayRef], + ) -> Result<(Self, RecordBatch)> { + // returns the new data batch sorted by hashes + + assert!( + num_rows < 1073741824, + "join hash table: number of rows exceeded 2^30: {num_rows}" + ); + + let num_entries = Self::num_entries_of_rows(num_rows) as u32; + let item_hashes = join_create_hashes(num_rows, &key_columns)?; + + // sort record batch by hashes for better compression and data locality + let (indices, item_hashes): (Vec, Vec) = item_hashes + .into_iter() + .enumerate() + .sorted_unstable_by_key(|(_idx, hash)| *hash) + .unzip(); + let data_batch = take_batch(data_batch, indices)?; + + let mut entries_to_row_indices: HashMap> = HashMap::new(); + for (row_idx, hash) in item_hashes.iter().enumerate() { + let entry = hash % num_entries; + entries_to_row_indices + .entry(entry) + .or_default() + .push(row_idx as u32); + } + + let mut entry_offsets = Vec::with_capacity(num_entries as usize); + let mut entry_lens = Vec::with_capacity(num_entries as usize); + let mut item_indices = Vec::with_capacity(num_rows); + for entry in 0..num_entries { + match entries_to_row_indices.get(&entry) { + Some(row_indices) => { + entry_offsets.push(item_indices.len() as u32); + entry_lens.push(row_indices.len() as u32); + item_indices.extend_from_slice(row_indices); + } + None => { + entry_offsets.push(item_indices.len() as u32); + entry_lens.push(0); + } + } + } + let new = Self { + entry_offsets, + entry_lens, + item_indices, + item_hashes, + }; + Ok((new, data_batch)) + } + + pub fn try_from_raw_bytes(raw_bytes: &[u8]) -> Result { + let mut cursor = Cursor::new(raw_bytes); + let num_rows = cursor.read_u32::()? 
as usize; + let num_entries = Self::num_entries_of_rows(num_rows); + + let mut new = Self { + entry_offsets: vec![0; num_entries], + entry_lens: vec![0; num_entries], + item_indices: vec![0; num_rows], + item_hashes: vec![0; num_rows], + }; + + unsafe { + // safety: read integer arrays as raw bytes + cursor.read_exact(from_raw_parts_mut( + new.entry_offsets.as_mut_ptr() as *mut u8, + num_entries * 4, + ))?; + cursor.read_exact(from_raw_parts_mut( + new.entry_lens.as_mut_ptr() as *mut u8, + num_entries * 4, + ))?; + cursor.read_exact(from_raw_parts_mut( + new.item_indices.as_mut_ptr() as *mut u8, + num_rows * 4, + ))?; + cursor.read_exact(from_raw_parts_mut( + new.item_hashes.as_mut_ptr() as *mut u8, + num_rows * 4, + ))?; + } + Ok(new) + } + + pub fn try_into_raw_bytes(self) -> Result> { + let num_entries = self.entry_offsets.len(); + let num_rows = self.item_indices.len(); + let mut raw_bytes = Vec::with_capacity(num_entries * 8 + num_rows * 4 + 4); + + raw_bytes.write_u32::(num_rows as u32)?; + unsafe { + // safety: write integer arrays as raw bytes + raw_bytes.write_all(from_raw_parts( + self.entry_offsets.as_ptr() as *const u8, + num_entries * 4, + ))?; + raw_bytes.write_all(from_raw_parts( + self.entry_lens.as_ptr() as *const u8, + num_entries * 4, + ))?; + raw_bytes.write_all(from_raw_parts( + self.item_indices.as_ptr() as *const u8, + num_rows * 4, + ))?; + raw_bytes.write_all(from_raw_parts( + self.item_hashes.as_ptr() as *const u8, + num_rows * 4, + ))?; + } + Ok(raw_bytes) + } + + pub fn entry<'a>(&'a self, hash: u32) -> Option + 'a> { + let entry = hash % (self.entry_offsets.len() as u32); + let len = self.entry_lens[entry as usize] as usize; + if len > 0 { + let offset = self.entry_offsets[entry as usize] as usize; + Some( + self.item_indices[offset..][..len] + .iter() + .cloned() + .filter(move |&idx| self.item_hashes[idx as usize] == hash), + ) + } else { + None + } + } + + fn num_entries_of_rows(num_rows: usize) -> usize { + num_rows * 3 + 1 + } +} + +pub struct JoinHashMap { + data_batch: RecordBatch, + key_columns: Vec, + table: Table, +} + +impl JoinHashMap { + pub fn try_from_data_batch( + data_batch: RecordBatch, + key_exprs: &[PhysicalExprRef], + ) -> Result { + let key_columns: Vec = key_exprs + .iter() + .map(|expr| { + Ok(expr + .evaluate(&data_batch)? + .into_array(data_batch.num_rows())?) + }) + .collect::>()?; + + let (table, data_batch) = + Table::try_from_key_columns(data_batch.num_rows(), data_batch, &key_columns)?; + Ok(JoinHashMap { + data_batch, + key_columns, + table, + }) + } + + pub fn try_from_hash_map_batch( + hash_map_batch: RecordBatch, + key_exprs: &[PhysicalExprRef], + ) -> Result { + let mut data_batch = hash_map_batch.clone(); + let table = Table::try_from_raw_bytes( + data_batch + .remove_column(data_batch.num_columns() - 1) + .as_binary::() + .value(0), + )?; + let key_columns: Vec = key_exprs + .iter() + .map(|expr| { + Ok(expr + .evaluate(&data_batch)? + .into_array(data_batch.num_rows())?) + }) + .collect::>()?; + Ok(Self { + data_batch, + key_columns, + table, + }) + } + + pub fn try_new_empty( + hash_map_schema: SchemaRef, + key_exprs: &[PhysicalExprRef], + ) -> Result { + let table = Table::new_empty(); + let data_batch = RecordBatch::new_empty(hash_map_schema); + let key_columns: Vec = key_exprs + .iter() + .map(|expr| { + Ok(expr + .evaluate(&data_batch)? + .into_array(data_batch.num_rows())?) 
+ }) + .collect::>()?; + Ok(Self { + data_batch, + key_columns, + table, + }) + } + + pub fn data_schema(&self) -> SchemaRef { + self.data_batch().schema() + } + + pub fn data_batch(&self) -> &RecordBatch { + &self.data_batch + } + + pub fn key_columns(&self) -> &[ArrayRef] { + &self.key_columns + } + + pub fn entry_indices<'a>(&'a self, hash: u32) -> Option + 'a> { + self.table.entry(hash) + } + + pub fn into_hash_map_batch(self) -> Result { + let schema = join_hash_map_schema(&self.data_batch.schema()); + if self.data_batch.num_rows() == 0 { + return Ok(RecordBatch::new_empty(schema)); + } + let mut table_col_builder = BinaryBuilder::new(); + table_col_builder.append_value(&self.table.try_into_raw_bytes()?); + for _ in 1..self.data_batch.num_rows() { + table_col_builder.append_null(); + } + let table_col: ArrayRef = Arc::new(table_col_builder.finish()); + Ok(RecordBatch::try_new( + schema, + vec![self.data_batch.columns().to_vec(), vec![table_col]].concat(), + )?) + } +} + +#[inline] +pub fn join_data_schema(hash_map_schema: &SchemaRef) -> SchemaRef { + Arc::new(Schema::new( + hash_map_schema + .fields() + .iter() + .take(hash_map_schema.fields().len() - 1) // exclude hash map column + .cloned() + .collect::>(), + )) +} + +#[inline] +pub fn join_hash_map_schema(data_schema: &SchemaRef) -> SchemaRef { + Arc::new(Schema::new( + data_schema + .fields() + .iter() + .map(|field| Arc::new(field.as_ref().clone().with_nullable(true))) + .chain(std::iter::once(join_table_field())) + .collect::>(), + )) +} + +#[inline] +pub fn join_create_hashes(num_rows: usize, key_columns: &[ArrayRef]) -> Result> { + const JOIN_HASH_RANDOM_SEED: u32 = 0x90ec4058; + let mut hashes = vec![JOIN_HASH_RANDOM_SEED; num_rows]; + create_hashes(key_columns, &mut hashes)?; + Ok(hashes) +} + +#[inline] +fn join_table_field() -> FieldRef { + static BHJ_KEY_FIELD: OnceCell = OnceCell::new(); + BHJ_KEY_FIELD + .get_or_init(|| Arc::new(Field::new("~TABLE", DataType::Binary, true))) + .clone() +} diff --git a/native-engine/datafusion-ext-plans/src/joins/join_utils.rs b/native-engine/datafusion-ext-plans/src/joins/join_utils.rs new file mode 100644 index 000000000..076cfa165 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/join_utils.rs @@ -0,0 +1,64 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
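The Table above packs all row indices into one flat vector and addresses it through per-bucket (offset, len) pairs, with bucket = hash % num_entries and num_entries = 3 * num_rows + 1. A self-contained, std-only sketch of that layout (ignoring the sort-by-hash step and the raw-bytes serialization) might look like this:

struct Table {
    entry_offsets: Vec<u32>,
    entry_lens: Vec<u32>,
    item_indices: Vec<u32>,
    item_hashes: Vec<u32>,
}

impl Table {
    fn from_hashes(item_hashes: Vec<u32>) -> Self {
        // same load factor as num_entries_of_rows() above
        let num_entries = item_hashes.len() * 3 + 1;
        let mut buckets: Vec<Vec<u32>> = vec![vec![]; num_entries];
        for (row_idx, &hash) in item_hashes.iter().enumerate() {
            buckets[(hash as usize) % num_entries].push(row_idx as u32);
        }
        let (mut entry_offsets, mut entry_lens, mut item_indices) = (vec![], vec![], vec![]);
        for rows in &buckets {
            entry_offsets.push(item_indices.len() as u32);
            entry_lens.push(rows.len() as u32);
            item_indices.extend_from_slice(rows);
        }
        Self { entry_offsets, entry_lens, item_indices, item_hashes }
    }

    // Returns candidate row indices whose stored hash equals `hash`; callers
    // still compare the real key values to rule out collisions.
    fn entry(&self, hash: u32) -> impl Iterator<Item = u32> + '_ {
        let entry = (hash as usize) % self.entry_offsets.len();
        let offset = self.entry_offsets[entry] as usize;
        let len = self.entry_lens[entry] as usize;
        self.item_indices[offset..offset + len]
            .iter()
            .copied()
            .filter(move |&idx| self.item_hashes[idx as usize] == hash)
    }
}

fn main() {
    let table = Table::from_hashes(vec![7, 7, 42]);
    assert_eq!(table.entry(7).collect::<Vec<_>>(), vec![0, 1]);
    assert_eq!(table.entry(42).collect::<Vec<_>>(), vec![2]);
}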
+ +use datafusion::common::{DataFusionError, Result}; +use datafusion_ext_commons::df_execution_err; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum JoinType { + Inner, + Left, + Right, + Full, + LeftAnti, + RightAnti, + LeftSemi, + RightSemi, + Existence, +} + +impl TryFrom for datafusion::prelude::JoinType { + type Error = DataFusionError; + + fn try_from(value: JoinType) -> Result { + match value { + JoinType::Inner => Ok(datafusion::prelude::JoinType::Inner), + JoinType::Left => Ok(datafusion::prelude::JoinType::Left), + JoinType::Right => Ok(datafusion::prelude::JoinType::Right), + JoinType::Full => Ok(datafusion::prelude::JoinType::Full), + JoinType::LeftAnti => Ok(datafusion::prelude::JoinType::LeftAnti), + JoinType::RightAnti => Ok(datafusion::prelude::JoinType::RightAnti), + JoinType::LeftSemi => Ok(datafusion::prelude::JoinType::LeftSemi), + JoinType::RightSemi => Ok(datafusion::prelude::JoinType::RightSemi), + other => df_execution_err!("unsupported join type: {other:?}"), + } + } +} + +impl TryFrom for JoinType { + type Error = DataFusionError; + + fn try_from(value: datafusion::prelude::JoinType) -> Result { + match value { + datafusion::prelude::JoinType::Inner => Ok(JoinType::Inner), + datafusion::prelude::JoinType::Left => Ok(JoinType::Left), + datafusion::prelude::JoinType::Right => Ok(JoinType::Right), + datafusion::prelude::JoinType::Full => Ok(JoinType::Full), + datafusion::prelude::JoinType::LeftAnti => Ok(JoinType::LeftAnti), + datafusion::prelude::JoinType::RightAnti => Ok(JoinType::RightAnti), + datafusion::prelude::JoinType::LeftSemi => Ok(JoinType::LeftSemi), + datafusion::prelude::JoinType::RightSemi => Ok(JoinType::RightSemi), + } + } +} diff --git a/native-engine/datafusion-ext-plans/src/joins/mod.rs b/native-engine/datafusion-ext-plans/src/joins/mod.rs new file mode 100644 index 000000000..3505a9a77 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/mod.rs @@ -0,0 +1,113 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
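The conversion above is TryFrom rather than From only because Existence has no counterpart in datafusion::prelude::JoinType. A toy illustration of the same fallible-mapping pattern, with made-up enum names rather than the real types, would be:

#[derive(Debug, PartialEq)]
enum BlazeJoinType { Inner, LeftSemi, Existence }

#[derive(Debug, PartialEq)]
enum DfJoinType { Inner, LeftSemi }

impl TryFrom<BlazeJoinType> for DfJoinType {
    type Error = String;
    fn try_from(value: BlazeJoinType) -> Result<Self, Self::Error> {
        match value {
            BlazeJoinType::Inner => Ok(DfJoinType::Inner),
            BlazeJoinType::LeftSemi => Ok(DfJoinType::LeftSemi),
            // the Existence variant stays Blaze-internal
            other => Err(format!("unsupported join type: {other:?}")),
        }
    }
}

fn main() {
    assert_eq!(DfJoinType::try_from(BlazeJoinType::Inner), Ok(DfJoinType::Inner));
    assert_eq!(DfJoinType::try_from(BlazeJoinType::LeftSemi), Ok(DfJoinType::LeftSemi));
    assert!(DfJoinType::try_from(BlazeJoinType::Existence).is_err());
}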
+ +use std::sync::Arc; + +use arrow::{ + array::ArrayRef, + compute::SortOptions, + datatypes::{DataType, SchemaRef}, +}; +use datafusion::{common::Result, physical_expr::PhysicalExprRef}; + +use crate::joins::{join_utils::JoinType, stream_cursor::StreamCursor}; + +pub mod join_hash_map; +pub mod join_utils; +pub mod stream_cursor; + +// join implementations +pub mod bhj; +pub mod smj; +mod test; + +#[derive(Debug, Clone)] +pub struct JoinParams { + pub join_type: JoinType, + pub left_schema: SchemaRef, + pub right_schema: SchemaRef, + pub output_schema: SchemaRef, + pub left_keys: Vec, + pub right_keys: Vec, + pub key_data_types: Vec, + pub sort_options: Vec, + pub projection: JoinProjection, + pub batch_size: usize, +} + +#[derive(Debug, Clone)] +pub struct JoinProjection { + pub schema: SchemaRef, + pub left_schema: SchemaRef, + pub right_schema: SchemaRef, + pub left: Vec, + pub right: Vec, +} + +impl JoinProjection { + pub fn try_new( + join_type: JoinType, + schema: &SchemaRef, + left_schema: &SchemaRef, + right_schema: &SchemaRef, + projection: &[usize], + ) -> Result { + let projected_schema = Arc::new(schema.project(projection)?); + let mut left = vec![]; + let mut right = vec![]; + + match join_type { + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { + for &i in projection { + if i < left_schema.fields().len() { + left.push(i); + } else if i - left_schema.fields().len() < right_schema.fields().len() { + right.push(i - left_schema.fields().len()); + } + } + } + JoinType::LeftAnti | JoinType::LeftSemi => { + left = projection.to_vec(); + } + JoinType::RightAnti | JoinType::RightSemi => { + right = projection.to_vec(); + } + JoinType::Existence => { + for &i in projection { + if i < left_schema.fields().len() { + left.push(i); + } + } + } + } + Ok(Self { + schema: projected_schema, + left_schema: Arc::new(left_schema.project(&left)?), + right_schema: Arc::new(right_schema.project(&right)?), + left, + right, + }) + } + + pub fn project_left(&self, cols: &[ArrayRef]) -> Vec { + self.left.iter().map(|&i| cols[i].clone()).collect() + } + + pub fn project_right(&self, cols: &[ArrayRef]) -> Vec { + self.right.iter().map(|&i| cols[i].clone()).collect() + } +} + +pub type Idx = (usize, usize); +pub type StreamCursors = (StreamCursor, StreamCursor); diff --git a/native-engine/datafusion-ext-plans/src/joins/smj/existence_join.rs b/native-engine/datafusion-ext-plans/src/joins/smj/existence_join.rs new file mode 100644 index 000000000..5749eb01b --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/smj/existence_join.rs @@ -0,0 +1,175 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
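JoinProjection::try_new splits one flat list of output column indices into per-side index lists. A small standalone sketch of that split for the Inner/Left/Right/Full cases, where left_len and right_len stand for the column counts of the two inputs, is:

// Indices below the left arity select left columns; the rest are rebased onto
// the right side, matching the Inner/Left/Right/Full branch above.
fn split_projection(projection: &[usize], left_len: usize, right_len: usize) -> (Vec<usize>, Vec<usize>) {
    let (mut left, mut right) = (vec![], vec![]);
    for &i in projection {
        if i < left_len {
            left.push(i);
        } else if i - left_len < right_len {
            right.push(i - left_len);
        }
    }
    (left, right)
}

fn main() {
    // left table has 3 columns, right table has 2; project output columns 1, 3, 4
    let (left, right) = split_projection(&[1, 3, 4], 3, 2);
    assert_eq!(left, vec![1]);      // second left column
    assert_eq!(right, vec![0, 1]);  // both right columns, rebased
}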
+ +use std::{cmp::Ordering, pin::Pin, sync::Arc}; + +use arrow::array::{ArrayRef, RecordBatch, RecordBatchOptions}; +use async_trait::async_trait; +use datafusion::{common::Result, physical_plan::metrics::Time}; +use datafusion_ext_commons::suggested_output_batch_mem_size; + +use crate::{ + common::{batch_selection::interleave_batches, output::WrappedRecordBatchSender}, + compare_cursor, cur_forward, + joins::{Idx, JoinParams, StreamCursors}, + sort_merge_join_exec::Joiner, +}; + +pub struct ExistenceJoiner { + join_params: JoinParams, + output_sender: Arc, + indices: Vec, + exists: Vec, + send_output_time: Time, + output_rows: usize, +} + +impl ExistenceJoiner { + pub fn new(join_params: JoinParams, output_sender: Arc) -> Self { + Self { + join_params, + output_sender, + indices: vec![], + exists: vec![], + send_output_time: Time::new(), + output_rows: 0, + } + } + + fn should_flush(&self, curs: &StreamCursors) -> bool { + if self.indices.len() >= self.join_params.batch_size { + return true; + } + + if curs.0.num_buffered_batches() + curs.1.num_buffered_batches() >= 6 + && curs.0.mem_size() + curs.1.mem_size() > suggested_output_batch_mem_size() + { + if let Some(first_idx) = self.indices.first() { + if first_idx.0 < curs.0.cur_idx.0 { + return true; + } + } + } + false + } + + async fn flush(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> { + let indices = std::mem::take(&mut self.indices); + let num_rows = indices.len(); + let cols = interleave_batches( + curs.0.projected_batch_schema.clone(), + &curs.0.projected_batches, + &indices, + )?; + + let exists = std::mem::take(&mut self.exists); + let exists_col: ArrayRef = Arc::new(arrow::array::BooleanArray::from(exists)); + + let output_batch = RecordBatch::try_new_with_options( + self.join_params.output_schema.clone(), + [cols.columns().to_vec(), vec![exists_col]].concat(), + &RecordBatchOptions::new().with_row_count(Some(num_rows)), + )?; + + if output_batch.num_rows() > 0 { + self.output_rows += output_batch.num_rows(); + + let timer = self.send_output_time.timer(); + self.output_sender.send(Ok(output_batch), None).await; + drop(timer); + } + Ok(()) + } +} + +#[async_trait] +impl Joiner for ExistenceJoiner { + async fn join(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> { + while !curs.0.finished && !curs.1.finished { + let mut lidx = curs.0.cur_idx; + let mut ridx = curs.1.cur_idx; + + match compare_cursor!(curs) { + Ordering::Less => { + self.indices.push(curs.0.cur_idx); + self.exists.push(false); + cur_forward!(curs.0); + if self.should_flush(curs) { + self.as_mut().flush(curs).await?; + } + curs.0 + .set_min_reserved_idx(*self.indices.first().unwrap_or(&curs.0.cur_idx)); + } + Ordering::Greater => { + cur_forward!(curs.1); + curs.1 + .set_min_reserved_idx(*self.indices.first().unwrap_or(&curs.1.cur_idx)); + } + Ordering::Equal => { + loop { + self.indices.push(lidx); + self.exists.push(true); + cur_forward!(curs.0); + if self.should_flush(curs) { + self.as_mut().flush(curs).await?; + } + curs.0 + .set_min_reserved_idx(*self.indices.first().unwrap_or(&lidx)); + + if !curs.0.finished && curs.0.key(curs.0.cur_idx) == curs.0.key(lidx) { + lidx = curs.0.cur_idx; + continue; + } + break; + } + + // skip all right equal rows + loop { + cur_forward!(curs.1); + curs.1.set_min_reserved_idx(ridx); + + if !curs.1.finished && curs.1.key(curs.1.cur_idx) == curs.1.key(ridx) { + ridx = curs.1.cur_idx; + continue; + } + break; + } + } + } + } + + while !curs.0.finished { + self.indices.push(curs.0.cur_idx); + 
self.exists.push(false); + cur_forward!(curs.0); + if self.should_flush(curs) { + self.as_mut().flush(curs).await?; + } + curs.0 + .set_min_reserved_idx(*self.indices.first().unwrap_or(&curs.0.cur_idx)); + } + if !self.indices.is_empty() { + self.flush(curs).await?; + } + Ok(()) + } + + fn total_send_output_time(&self) -> usize { + self.send_output_time.value() + } + + fn num_output_rows(&self) -> usize { + self.output_rows + } +} diff --git a/native-engine/datafusion-ext-plans/src/joins/smj/full_join.rs b/native-engine/datafusion-ext-plans/src/joins/smj/full_join.rs new file mode 100644 index 000000000..55967f457 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/smj/full_join.rs @@ -0,0 +1,248 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{cmp::Ordering, pin::Pin, sync::Arc}; + +use arrow::array::{RecordBatch, RecordBatchOptions}; +use async_trait::async_trait; +use datafusion::{common::Result, physical_plan::metrics::Time}; +use datafusion_ext_commons::suggested_output_batch_mem_size; +use smallvec::{smallvec, SmallVec}; + +use crate::{ + common::{batch_selection::interleave_batches, output::WrappedRecordBatchSender}, + compare_cursor, cur_forward, + joins::{Idx, JoinParams, StreamCursors}, + sort_merge_join_exec::Joiner, +}; + +pub struct FullJoiner { + join_params: JoinParams, + output_sender: Arc, + lindices: Vec, + rindices: Vec, + send_output_time: Time, + output_rows: usize, +} + +pub type InnerJoiner = FullJoiner; +pub type LeftOuterJoiner = FullJoiner; +pub type RightOuterJoiner = FullJoiner; +pub type FullOuterJoiner = FullJoiner; + +impl FullJoiner { + pub fn new(join_params: JoinParams, output_sender: Arc) -> Self { + Self { + join_params, + output_sender, + lindices: vec![], + rindices: vec![], + send_output_time: Time::new(), + output_rows: 0, + } + } + + fn should_flush(&self, curs: &StreamCursors) -> bool { + if self.lindices.len() >= self.join_params.batch_size { + return true; + } + + if curs.0.num_buffered_batches() + curs.1.num_buffered_batches() >= 6 + && curs.0.mem_size() + curs.1.mem_size() > suggested_output_batch_mem_size() + { + if let Some(first_lidx) = self.lindices.first() { + if first_lidx.0 < curs.0.cur_idx.0 { + return true; + } + } + if let Some(first_ridx) = self.rindices.first() { + if first_ridx.0 < curs.1.cur_idx.0 { + return true; + } + } + } + false + } + + async fn flush(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> { + let lindices = std::mem::take(&mut self.lindices); + let rindices = std::mem::take(&mut self.rindices); + let num_rows = lindices.len(); + assert_eq!(lindices.len(), rindices.len()); + + let lcols = interleave_batches( + curs.0.projected_batch_schema.clone(), + &curs.0.projected_batches, + &lindices, + )?; + let rcols = interleave_batches( + curs.1.projected_batch_schema.clone(), + &curs.1.projected_batches, + &rindices, + )?; + let output_batch = RecordBatch::try_new_with_options( + self.join_params.projection.schema.clone(), + 
[lcols.columns(), rcols.columns()].concat(), + &RecordBatchOptions::new().with_row_count(Some(num_rows)), + )?; + + if output_batch.num_rows() > 0 { + self.output_rows += output_batch.num_rows(); + + let timer = self.send_output_time.timer(); + self.output_sender.send(Ok(output_batch), None).await; + drop(timer); + } + Ok(()) + } +} + +#[async_trait] +impl Joiner for FullJoiner { + async fn join(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> { + while !curs.0.finished && !curs.1.finished { + let mut lidx = curs.0.cur_idx; + let mut ridx = curs.1.cur_idx; + match compare_cursor!(curs) { + Ordering::Less => { + if L_OUTER { + self.lindices.push(lidx); + self.rindices.push(Idx::default()); + } + cur_forward!(curs.0); + if self.should_flush(curs) { + self.as_mut().flush(curs).await?; + } + curs.0 + .set_min_reserved_idx(*self.lindices.first().unwrap_or(&lidx)); + } + Ordering::Greater => { + if R_OUTER { + self.lindices.push(Idx::default()); + self.rindices.push(ridx); + } + cur_forward!(curs.1); + if self.should_flush(curs) { + self.as_mut().flush(curs).await?; + } + curs.1 + .set_min_reserved_idx(*self.rindices.first().unwrap_or(&ridx)); + } + Ordering::Equal => { + cur_forward!(curs.0); + cur_forward!(curs.1); + self.lindices.push(lidx); + self.rindices.push(ridx); + + let mut equal_lindices: SmallVec<[Idx; 16]> = smallvec![lidx]; + let mut equal_rindices: SmallVec<[Idx; 16]> = smallvec![ridx]; + let mut last_lidx = lidx; + let mut last_ridx = ridx; + lidx = curs.0.cur_idx; + ridx = curs.1.cur_idx; + let mut l_equal = !curs.0.finished && curs.0.key(lidx) == curs.0.key(last_lidx); + let mut r_equal = !curs.1.finished && curs.1.key(ridx) == curs.1.key(last_ridx); + + while l_equal || r_equal { + if l_equal { + for &ridx in &equal_rindices { + self.lindices.push(lidx); + self.rindices.push(ridx); + } + if r_equal { + equal_lindices.push(lidx); + } + cur_forward!(curs.0); + last_lidx = lidx; + lidx = curs.0.cur_idx; + } else { + curs.1 + .set_min_reserved_idx(*self.rindices.first().unwrap_or(&last_ridx)); + } + + if r_equal { + for &lidx in &equal_lindices { + self.lindices.push(lidx); + self.rindices.push(ridx); + } + if l_equal { + equal_rindices.push(ridx); + } + cur_forward!(curs.1); + last_ridx = ridx; + ridx = curs.1.cur_idx; + } else { + curs.0 + .set_min_reserved_idx(*self.lindices.first().unwrap_or(&last_lidx)); + } + + if self.should_flush(curs) { + self.as_mut().flush(curs).await?; + } + l_equal = l_equal + && !curs.0.finished + && curs.0.key(lidx) == curs.0.key(last_lidx); + r_equal = r_equal + && !curs.1.finished + && curs.1.key(ridx) == curs.1.key(last_ridx); + } + + if self.should_flush(curs) { + self.as_mut().flush(curs).await?; + } + curs.0 + .set_min_reserved_idx(*self.lindices.first().unwrap_or(&curs.0.cur_idx)); + curs.1 + .set_min_reserved_idx(*self.rindices.first().unwrap_or(&curs.1.cur_idx)); + } + } + } + + // at least one side is finished, consume the other side if it is an outer side + while L_OUTER && !curs.0.finished { + let lidx = curs.0.cur_idx; + self.lindices.push(lidx); + self.rindices.push(Idx::default()); + cur_forward!(curs.0); + if self.should_flush(curs) { + self.as_mut().flush(curs).await?; + } + curs.0 + .set_min_reserved_idx(*self.lindices.first().unwrap_or(&lidx)); + } + while R_OUTER && !curs.1.finished { + let ridx = curs.1.cur_idx; + self.lindices.push(Idx::default()); + self.rindices.push(ridx); + cur_forward!(curs.1); + if self.should_flush(curs) { + self.as_mut().flush(curs).await?; + } + curs.1 + 
.set_min_reserved_idx(*self.rindices.first().unwrap_or(&ridx)); + } + if !self.lindices.is_empty() { + self.flush(curs).await?; + } + Ok(()) + } + + fn total_send_output_time(&self) -> usize { + self.send_output_time.value() + } + + fn num_output_rows(&self) -> usize { + self.output_rows + } +} diff --git a/native-engine/datafusion-ext-plans/src/joins/smj/mod.rs b/native-engine/datafusion-ext-plans/src/joins/smj/mod.rs new file mode 100644 index 000000000..8bcdadff1 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/smj/mod.rs @@ -0,0 +1,17 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod existence_join; +pub mod full_join; +pub mod semi_join; diff --git a/native-engine/datafusion-ext-plans/src/joins/smj/semi_join.rs b/native-engine/datafusion-ext-plans/src/joins/smj/semi_join.rs new file mode 100644 index 000000000..fd5f9351e --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/smj/semi_join.rs @@ -0,0 +1,252 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
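The sort-merge joiners above stream over cursors and track (batch, row) Idx pairs, but the core merge logic is easier to see in a simplified, non-streaming form: both inputs are sorted on the key, and equal-key runs on the two sides expand into their cross product. A rough sketch over plain slices, not the crate's actual API:

fn merge_join_indices(left: &[i32], right: &[i32]) -> Vec<(usize, usize)> {
    let (mut l, mut r, mut out) = (0, 0, vec![]);
    while l < left.len() && r < right.len() {
        match left[l].cmp(&right[r]) {
            std::cmp::Ordering::Less => l += 1,
            std::cmp::Ordering::Greater => r += 1,
            std::cmp::Ordering::Equal => {
                // find the extent of the equal-key run on each side
                let l_end = (l..left.len()).take_while(|&i| left[i] == left[l]).last().unwrap() + 1;
                let r_end = (r..right.len()).take_while(|&i| right[i] == right[r]).last().unwrap() + 1;
                // emit the cross product of the two runs
                for li in l..l_end {
                    for ri in r..r_end {
                        out.push((li, ri));
                    }
                }
                l = l_end;
                r = r_end;
            }
        }
    }
    out
}

fn main() {
    // keys 4, 5, 5 against 4, 5, 6: row 0 matches row 0, rows 1 and 2 match row 1
    let pairs = merge_join_indices(&[4, 5, 5], &[4, 5, 6]);
    assert_eq!(pairs, vec![(0, 0), (1, 1), (2, 1)]);
}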
+ +use std::{cmp::Ordering, pin::Pin, sync::Arc}; + +use arrow::array::{RecordBatch, RecordBatchOptions}; +use async_trait::async_trait; +use datafusion::{common::Result, physical_plan::metrics::Time}; +use datafusion_ext_commons::suggested_output_batch_mem_size; + +use crate::{ + common::{batch_selection::interleave_batches, output::WrappedRecordBatchSender}, + compare_cursor, cur_forward, + joins::{ + smj::semi_join::SemiJoinSide::{L, R}, + Idx, JoinParams, StreamCursors, + }, + sort_merge_join_exec::Joiner, +}; + +#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)] +pub enum SemiJoinSide { + L, + R, +} + +#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)] +pub struct JoinerParams { + join_side: SemiJoinSide, + semi: bool, +} + +impl JoinerParams { + const fn new(join_side: SemiJoinSide, semi: bool) -> Self { + Self { join_side, semi } + } +} +pub struct SemiJoiner { + join_params: JoinParams, + output_sender: Arc, + indices: Vec, + send_output_time: Time, + output_rows: usize, +} + +const LEFT_SEMI: JoinerParams = JoinerParams::new(L, true); +const LEFT_ANTI: JoinerParams = JoinerParams::new(L, false); +const RIGHT_SEMI: JoinerParams = JoinerParams::new(R, true); +const RIGHT_ANTI: JoinerParams = JoinerParams::new(R, false); + +pub type LeftSemiJoiner = SemiJoiner; +pub type LeftAntiJoiner = SemiJoiner; +pub type RightSemiJoiner = SemiJoiner; +pub type RightAntiJoiner = SemiJoiner; + +impl SemiJoiner

{ + pub fn new(join_params: JoinParams, output_sender: Arc) -> Self { + Self { + join_params, + output_sender, + indices: vec![], + send_output_time: Time::new(), + output_rows: 0, + } + } + + fn should_flush(&self, curs: &StreamCursors) -> bool { + if self.indices.len() >= self.join_params.batch_size { + return true; + } + + if curs.0.num_buffered_batches() + curs.1.num_buffered_batches() >= 6 + && curs.0.mem_size() + curs.1.mem_size() > suggested_output_batch_mem_size() + { + if let Some(first_idx) = self.indices.first() { + let cur_idx = match P.join_side { + L => curs.0.cur_idx, + R => curs.1.cur_idx, + }; + if first_idx.0 < cur_idx.0 { + return true; + } + } + } + false + } + + async fn flush(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> { + let indices = std::mem::take(&mut self.indices); + let num_rows = indices.len(); + + let cols = match P.join_side { + L => interleave_batches( + curs.0.projected_batch_schema.clone(), + &curs.0.projected_batches, + &indices, + )?, + R => interleave_batches( + curs.1.projected_batch_schema.clone(), + &curs.1.projected_batches, + &indices, + )?, + }; + let output_batch = RecordBatch::try_new_with_options( + self.join_params.projection.schema.clone(), + cols.columns().to_vec(), + &RecordBatchOptions::new().with_row_count(Some(num_rows)), + )?; + + if output_batch.num_rows() > 0 { + self.output_rows += output_batch.num_rows(); + + let timer = self.send_output_time.timer(); + self.output_sender.send(Ok(output_batch), None).await; + drop(timer); + } + Ok(()) + } +} + +#[async_trait] +impl Joiner for SemiJoiner

{ + async fn join(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> { + while !curs.0.finished && !curs.1.finished { + let mut lidx = curs.0.cur_idx; + let mut ridx = curs.1.cur_idx; + + match compare_cursor!(curs) { + Ordering::Less => { + if P.join_side == L && !P.semi { + self.indices.push(lidx); + } + cur_forward!(curs.0); + if self.should_flush(curs) { + self.as_mut().flush(curs).await?; + } + curs.0.set_min_reserved_idx(match P.join_side { + L => *self.indices.first().unwrap_or(&lidx), + R => lidx, + }); + } + Ordering::Greater => { + if P.join_side == R && !P.semi { + self.indices.push(ridx); + } + cur_forward!(curs.1); + if self.should_flush(curs) { + self.as_mut().flush(curs).await?; + } + curs.1.set_min_reserved_idx(match P.join_side { + L => ridx, + R => *self.indices.first().unwrap_or(&ridx), + }); + } + Ordering::Equal => { + // output/skip left equal rows + loop { + if P.join_side == L && P.semi { + self.indices.push(lidx); + if self.should_flush(curs) { + self.as_mut().flush(curs).await?; + } + } + cur_forward!(curs.0); + curs.0.set_min_reserved_idx(match P.join_side { + L => *self.indices.first().unwrap_or(&lidx), + R => lidx, + }); + + if !curs.0.finished && curs.0.key(curs.0.cur_idx) == curs.0.key(lidx) { + lidx = curs.0.cur_idx; + continue; + } + break; + } + + // output/skip right equal rows + loop { + if P.join_side == R && P.semi { + self.indices.push(ridx); + if self.should_flush(curs) { + self.as_mut().flush(curs).await?; + } + } + cur_forward!(curs.1); + curs.1.set_min_reserved_idx(match P.join_side { + L => ridx, + R => *self.indices.first().unwrap_or(&ridx), + }); + + if !curs.1.finished && curs.1.key(curs.1.cur_idx) == curs.1.key(ridx) { + ridx = curs.1.cur_idx; + continue; + } + break; + } + } + } + } + + // at least one side is finished, consume the other side if it is an anti side + if !P.semi { + while P.join_side == L && !P.semi && !curs.0.finished { + let lidx = curs.0.cur_idx; + self.indices.push(lidx); + cur_forward!(curs.0); + if self.should_flush(curs) { + self.as_mut().flush(curs).await?; + } + curs.0.set_min_reserved_idx(match P.join_side { + L => *self.indices.first().unwrap_or(&lidx), + R => lidx, + }); + } + while P.join_side == R && !P.semi && !curs.1.finished { + let ridx = curs.1.cur_idx; + self.indices.push(ridx); + cur_forward!(curs.1); + if self.should_flush(curs) { + self.as_mut().flush(curs).await?; + } + curs.1.set_min_reserved_idx(match P.join_side { + L => ridx, + R => *self.indices.first().unwrap_or(&ridx), + }); + } + } + if !self.indices.is_empty() { + self.flush(curs).await?; + } + Ok(()) + } + + fn total_send_output_time(&self) -> usize { + self.send_output_time.value() + } + + fn num_output_rows(&self) -> usize { + self.output_rows + } +} diff --git a/native-engine/datafusion-ext-plans/src/joins/stream_cursor.rs b/native-engine/datafusion-ext-plans/src/joins/stream_cursor.rs new file mode 100644 index 000000000..c105bb8d7 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/stream_cursor.rs @@ -0,0 +1,235 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use arrow::{ + array::{RecordBatch, RecordBatchOptions}, + buffer::NullBuffer, + datatypes::{Schema, SchemaRef}, + row::{Row, RowConverter, Rows, SortField}, +}; +use datafusion::{ + common::{JoinSide, Result}, + execution::SendableRecordBatchStream, + physical_expr::PhysicalExprRef, + physical_plan::metrics::Time, +}; +use datafusion_ext_commons::array_size::ArraySize; +use futures::{Future, StreamExt}; +use parking_lot::Mutex; + +use crate::{ + common::batch_selection::take_batch_opt, + joins::{Idx, JoinParams}, +}; + +pub struct StreamCursor { + stream: SendableRecordBatchStream, + key_converter: Arc>, + key_exprs: Vec, + poll_time: Time, + + // IMPORTANT: + // batches/rows/null_buffers always contains a `null batch` in the front + projection: Vec, + pub projected_batch_schema: SchemaRef, + pub projected_batches: Vec, + pub cur_idx: Idx, + min_reserved_idx: Idx, + keys: Vec>, + key_has_nulls: Vec>, + num_null_batches: usize, + mem_size: usize, + pub finished: bool, +} + +impl StreamCursor { + pub fn try_new( + stream: SendableRecordBatchStream, + join_params: &JoinParams, + join_side: JoinSide, + projection: &[usize], + ) -> Result { + let key_converter = Arc::new(Mutex::new(RowConverter::new( + join_params + .key_data_types + .iter() + .cloned() + .zip(&join_params.sort_options) + .map(|(dt, options)| SortField::new_with_options(dt, *options)) + .collect(), + )?)); + let key_exprs = match join_side { + JoinSide::Left => join_params.left_keys.clone(), + JoinSide::Right => join_params.right_keys.clone(), + }; + + let empty_batch = RecordBatch::new_empty(Arc::new(Schema::new( + stream + .schema() + .fields() + .iter() + .map(|f| f.as_ref().clone().with_nullable(true)) + .collect::>(), + ))); + let empty_keys = Arc::new( + key_converter.lock().convert_columns( + &key_exprs + .iter() + .map(|key| Ok(key.evaluate(&empty_batch)?.into_array(0)?)) + .collect::>>()?, + )?, + ); + let null_batch = take_batch_opt(empty_batch, [Option::::None])?; + let projected_null_batch = null_batch.project(projection)?; + let null_nb = NullBuffer::new_null(1); + + Ok(Self { + stream, + key_exprs, + key_converter, + poll_time: Time::new(), + projection: projection.to_vec(), + projected_batch_schema: projected_null_batch.schema(), + projected_batches: vec![projected_null_batch], + cur_idx: (0, 0), + min_reserved_idx: (0, 0), + keys: vec![empty_keys], + key_has_nulls: vec![Some(null_nb)], + num_null_batches: 1, + mem_size: 0, + finished: false, + }) + } + + pub fn next(&mut self) -> Option> + '_> { + self.cur_idx.1 += 1; + if self.cur_idx.1 >= self.projected_batches[self.cur_idx.0].num_rows() { + self.cur_idx.0 += 1; + self.cur_idx.1 = 0; + } + + let should_load_next_batch = self.cur_idx.0 >= self.projected_batches.len(); + if should_load_next_batch { + Some(async move { + while let Some(batch) = { + let timer = self.poll_time.timer(); + let batch = self.stream.next().await.transpose()?; + drop(timer); + batch + } { + if batch.num_rows() == 0 { + continue; + } + let key_columns = self + .key_exprs + .iter() + .map(|key| Ok(key.evaluate(&batch)?.into_array(batch.num_rows())?)) + .collect::>>()?; + let key_has_nulls = key_columns + .iter() + .map(|c| c.nulls().cloned()) + .reduce(|lhs, rhs| NullBuffer::union(lhs.as_ref(), rhs.as_ref())) + .unwrap_or(None); + let keys = Arc::new(self.key_converter.lock().convert_columns(&key_columns)?); + + self.mem_size += batch.get_array_mem_size(); + 
self.mem_size += key_has_nulls + .as_ref() + .map(|nb| nb.buffer().len()) + .unwrap_or_default(); + self.mem_size += keys.size(); + + self.projected_batches + .push(RecordBatch::try_new_with_options( + self.projected_batches[0].schema(), + self.projection + .iter() + .map(|&i| batch.column(i).clone()) + .collect(), + &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())), + )?); + self.key_has_nulls.push(key_has_nulls); + self.keys.push(keys); + + // fill out-dated batches with null batches + if self.num_null_batches < self.min_reserved_idx.0 { + for i in self.num_null_batches..self.min_reserved_idx.0 { + self.mem_size -= self.projected_batches[i].get_array_mem_size(); + self.mem_size -= self.key_has_nulls[i] + .as_ref() + .map(|nb| nb.buffer().len()) + .unwrap_or_default(); + self.mem_size -= self.keys[i].size(); + + self.projected_batches[i] = self.projected_batches[0].clone(); + self.keys[i] = self.keys[0].clone(); + self.key_has_nulls[i] = self.key_has_nulls[0].clone(); + self.num_null_batches += 1; + } + } + return Ok(()); + } + self.finished = true; + return Ok(()); + }) + } else { + None + } + } + + #[inline] + pub fn is_null_key(&self, idx: Idx) -> bool { + self.key_has_nulls[idx.0] + .as_ref() + .map(|nb| nb.is_null(idx.1)) + .unwrap_or(false) + } + + #[inline] + pub fn key<'a>(&'a self, idx: Idx) -> Row<'a> { + let keys = &self.keys[idx.0]; + keys.row(idx.1) + } + + #[inline] + pub fn num_buffered_batches(&self) -> usize { + self.projected_batches.len() - self.num_null_batches + } + + #[inline] + pub fn mem_size(&self) -> usize { + self.mem_size + } + + #[inline] + pub fn set_min_reserved_idx(&mut self, idx: Idx) { + self.min_reserved_idx = idx; + } + + #[inline] + pub fn total_poll_time(&self) -> usize { + self.poll_time.value() + } +} + +#[macro_export] +macro_rules! cur_forward { + ($cur:expr) => {{ + if let Some(fut) = $cur.next() { + fut.await?; + } + }}; +} diff --git a/native-engine/datafusion-ext-plans/src/joins/test.rs b/native-engine/datafusion-ext-plans/src/joins/test.rs new file mode 100644 index 000000000..e0826e7d0 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/test.rs @@ -0,0 +1,947 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
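StreamCursor keeps a shared null/placeholder batch at index 0 and, once min_reserved_idx moves past a buffered batch, swaps that batch out for the placeholder so earlier (batch, row) indices stay valid while the memory is released. A simplified, std-only sketch of that trick, with Arc<Vec<i32>> standing in for RecordBatch and made-up names:

use std::sync::Arc;

struct Cursor {
    // index 0 is always a shared empty placeholder; real batches follow it
    batches: Vec<Arc<Vec<i32>>>,
    num_released: usize,     // counts the placeholder plus every released batch
    min_reserved_idx: usize, // smallest batch index still referenced by pending output
}

impl Cursor {
    fn release_consumed_batches(&mut self) {
        let placeholder = self.batches[0].clone();
        for i in self.num_released..self.min_reserved_idx {
            // swap the old batch for the placeholder: indices stay valid,
            // but the batch memory can be reclaimed
            self.batches[i] = placeholder.clone();
            self.num_released += 1;
        }
    }
}

fn main() {
    let mut cursor = Cursor {
        batches: vec![
            Arc::new(vec![]),       // placeholder
            Arc::new(vec![1, 2]),
            Arc::new(vec![3, 4]),
            Arc::new(vec![5]),
        ],
        num_released: 1,
        min_reserved_idx: 3,
    };
    cursor.release_consumed_batches();
    assert!(cursor.batches[1].is_empty() && cursor.batches[2].is_empty());
    assert_eq!(*cursor.batches[3], vec![5]); // still reserved, kept intact
}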
+ +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::{ + self, + array::*, + compute::SortOptions, + datatypes::{DataType, Field, Schema, SchemaRef}, + record_batch::RecordBatch, + }; + use datafusion::{ + assert_batches_sorted_eq, + common::JoinSide, + error::Result, + physical_expr::expressions::Column, + physical_plan::{common, joins::utils::*, memory::MemoryExec, ExecutionPlan}, + prelude::SessionContext, + }; + use TestType::*; + + use crate::{ + broadcast_join_build_hash_map_exec::BroadcastJoinBuildHashMapExec, + broadcast_join_exec::BroadcastJoinExec, + joins::join_utils::{JoinType, JoinType::*}, + sort_merge_join_exec::SortMergeJoinExec, + }; + + #[derive(Clone, Copy)] + enum TestType { + SMJ, + BHJLeftProbed, + BHJRightProbed, + } + + fn columns(schema: &Schema) -> Vec { + schema.fields().iter().map(|f| f.name().clone()).collect() + } + + fn build_table_i32( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> RecordBatch { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Int32, false), + Field::new(b.0, DataType::Int32, false), + Field::new(c.0, DataType::Int32, false), + ]); + + RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Int32Array::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + Arc::new(Int32Array::from(c.1.clone())), + ], + ) + .unwrap() + } + + fn build_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let batch = build_table_i32(a, b, c); + let schema = batch.schema(); + Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) + } + + fn build_table_from_batches(batches: Vec) -> Arc { + let schema = batches.first().unwrap().schema(); + Arc::new(MemoryExec::try_new(&[batches], schema, None).unwrap()) + } + + fn build_date_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Date32, false), + Field::new(b.0, DataType::Date32, false), + Field::new(c.0, DataType::Date32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Date32Array::from(a.1.clone())), + Arc::new(Date32Array::from(b.1.clone())), + Arc::new(Date32Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) + } + + fn build_date64_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Date64, false), + Field::new(b.0, DataType::Date64, false), + Field::new(c.0, DataType::Date64, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Date64Array::from(a.1.clone())), + Arc::new(Date64Array::from(b.1.clone())), + Arc::new(Date64Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) + } + + /// returns a table with 3 columns of i32 in memory + pub fn build_table_i32_nullable( + a: (&str, &Vec>), + b: (&str, &Vec>), + c: (&str, &Vec>), + ) -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new(a.0, DataType::Int32, true), + Field::new(b.0, DataType::Int32, true), + Field::new(c.0, DataType::Int32, true), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + Arc::new(Int32Array::from(c.1.clone())), + ], + ) + .unwrap(); + Arc::new(MemoryExec::try_new(&[vec![batch]], 
schema, None).unwrap()) + } + + fn build_join_schema_for_test( + left: &Schema, + right: &Schema, + join_type: JoinType, + ) -> Result { + if join_type == Existence { + let exists_field = Arc::new(Field::new("exists#0", DataType::Boolean, false)); + return Ok(Arc::new(Schema::new( + [left.fields().to_vec(), vec![exists_field]].concat(), + ))); + } + Ok(Arc::new( + build_join_schema(left, right, &join_type.try_into()?).0, + )) + } + + async fn join_collect( + test_type: TestType, + left: Arc, + right: Arc, + on: JoinOn, + join_type: JoinType, + ) -> Result<(Vec, Vec)> { + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + let schema = build_join_schema_for_test(&left.schema(), &right.schema(), join_type)?; + + let join: Arc = match test_type { + SMJ => { + let sort_options = vec![SortOptions::default(); on.len()]; + Arc::new(SortMergeJoinExec::try_new( + schema, + left, + right, + on, + join_type, + sort_options, + )?) + } + BHJLeftProbed => { + let right = Arc::new(BroadcastJoinBuildHashMapExec::new( + right, + on.iter().map(|(_, right_key)| right_key.clone()).collect(), + )); + Arc::new(BroadcastJoinExec::try_new( + schema, + left, + right, + on, + join_type, + JoinSide::Right, + None, + )?) + } + BHJRightProbed => { + let left = Arc::new(BroadcastJoinBuildHashMapExec::new( + left, + on.iter().map(|(left_key, _)| left_key.clone()).collect(), + )); + Arc::new(BroadcastJoinExec::try_new( + schema, + left, + right, + on, + join_type, + JoinSide::Left, + None, + )?) + } + }; + let columns = columns(&join.schema()); + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + Ok((columns, batches)) + } + + #[tokio::test] + async fn join_inner_one() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b1", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Inner).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_inner_two() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b2", &vec![1, 2, 2]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a1", &vec![1, 2, 3]), + ("b2", &vec![1, 2, 2]), + ("c2", &vec![70, 80, 90]), + ); + let on: JoinOn = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?), + Arc::new(Column::new_with_schema("a1", &right.schema())?), + ), + ( + Arc::new(Column::new_with_schema("b2", &left.schema())?), + Arc::new(Column::new_with_schema("b2", &right.schema())?), + ), + ]; + + let (_columns, batches) = join_collect(test_type, left, right, on, Inner).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b2 | c1 | a1 | b2 | c2 |", + "+----+----+----+----+----+----+", + 
"| 1 | 1 | 7 | 1 | 1 | 70 |", + "| 2 | 2 | 8 | 2 | 2 | 80 |", + "| 2 | 2 | 9 | 2 | 2 | 80 |", + "+----+----+----+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_inner_two_two() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table( + ("a1", &vec![1, 1, 2]), + ("b2", &vec![1, 1, 2]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a1", &vec![1, 1, 3]), + ("b2", &vec![1, 1, 2]), + ("c2", &vec![70, 80, 90]), + ); + let on: JoinOn = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?), + Arc::new(Column::new_with_schema("a1", &right.schema())?), + ), + ( + Arc::new(Column::new_with_schema("b2", &left.schema())?), + Arc::new(Column::new_with_schema("b2", &right.schema())?), + ), + ]; + + let (_columns, batches) = join_collect(test_type, left, right, on, Inner).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b2 | c1 | a1 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 1 | 7 | 1 | 1 | 70 |", + "| 1 | 1 | 7 | 1 | 1 | 80 |", + "| 1 | 1 | 8 | 1 | 1 | 70 |", + "| 1 | 1 | 8 | 1 | 1 | 80 |", + "+----+----+----+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_inner_with_nulls() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(1), Some(2), Some(2)]), + ("b2", &vec![None, Some(1), Some(2), Some(2)]), // null in key field + ("c1", &vec![Some(1), None, Some(8), Some(9)]), // null in non-key field + ); + let right = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(1), Some(2), Some(3)]), + ("b2", &vec![None, Some(1), Some(2), Some(2)]), + ("c2", &vec![Some(10), Some(70), Some(80), Some(90)]), + ); + let on: JoinOn = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?), + Arc::new(Column::new_with_schema("a1", &right.schema())?), + ), + ( + Arc::new(Column::new_with_schema("b2", &left.schema())?), + Arc::new(Column::new_with_schema("b2", &right.schema())?), + ), + ]; + + let (_, batches) = join_collect(test_type, left, right, on, Inner).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b2 | c1 | a1 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 1 | | 1 | 1 | 70 |", + "| 2 | 2 | 8 | 2 | 2 | 80 |", + "| 2 | 2 | 9 | 2 | 2 | 80 |", + "+----+----+----+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_left_one() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b1", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Left).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 
8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | | |", + "+----+----+----+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_right_one() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), // 6 does not exist on the left + ("c2", &vec![70, 80, 90]), + ); + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b1", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Right).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| | | | 30 | 6 | 90 |", + "+----+----+----+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_full_one() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b2", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b2", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Full).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| | | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | | |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_anti() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table( + ("a1", &vec![1, 2, 2, 3, 5]), + ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 8, 9, 11]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b1", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, LeftAnti).await?; + let expected = vec![ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "| 3 | 7 | 9 |", + "| 5 | 7 | 11 |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_semi() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), // 5 is double on the right + ("c2", &vec![70, 80, 90]), + ); + let on: JoinOn = vec![( + 
Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b1", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, LeftSemi).await?; + let expected = vec![ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_with_duplicated_column_names() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table( + ("a", &vec![1, 2, 3]), + ("b", &vec![4, 5, 7]), + ("c", &vec![7, 8, 9]), + ); + let right = build_table( + ("a", &vec![10, 20, 30]), + ("b", &vec![1, 2, 7]), + ("c", &vec![70, 80, 90]), + ); + let on: JoinOn = vec![( + // join on a=b so there are duplicate column names on unjoined columns + Arc::new(Column::new_with_schema("a", &left.schema())?), + Arc::new(Column::new_with_schema("b", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Inner).await?; + let expected = vec![ + "+---+---+---+----+---+----+", + "| a | b | c | a | b | c |", + "+---+---+---+----+---+----+", + "| 1 | 4 | 7 | 10 | 1 | 70 |", + "| 2 | 5 | 8 | 20 | 2 | 80 |", + "+---+---+---+----+---+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_date32() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_date_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![19107, 19108, 19108]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_date_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![19107, 19108, 19109]), + ("c2", &vec![70, 80, 90]), + ); + + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b1", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Inner).await?; + + let expected = vec![ + "+------------+------------+------------+------------+------------+------------+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+------------+------------+------------+------------+------------+------------+", + "| 1970-01-02 | 2022-04-25 | 1970-01-08 | 1970-01-11 | 2022-04-25 | 1970-03-12 |", + "| 1970-01-03 | 2022-04-26 | 1970-01-09 | 1970-01-21 | 2022-04-26 | 1970-03-22 |", + "| 1970-01-04 | 2022-04-26 | 1970-01-10 | 1970-01-21 | 2022-04-26 | 1970-03-22 |", + "+------------+------------+------------+------------+------------+------------+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_date64() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_date64_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1650703441000, 1650903441000, 1650903441000]), /* this has a + * repetition */ + ("c1", &vec![7, 8, 9]), + ); + let right = build_date64_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![1650703441000, 1650503441000, 1650903441000]), + ("c2", &vec![70, 80, 90]), + ); + + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b1", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, 
Inner).await?; + let expected = vec![ + "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", + "| 1970-01-01T00:00:00.001 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.007 | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.070 |", + "| 1970-01-01T00:00:00.002 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |", + "| 1970-01-01T00:00:00.003 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |", + "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", + ]; + + // The output order is important as SMJ preserves sortedness + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_left_sort_order() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table( + ("a1", &vec![0, 1, 2, 3, 4, 5]), + ("b1", &vec![3, 4, 5, 6, 6, 7]), + ("c1", &vec![4, 5, 6, 7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![0, 10, 20, 30, 40]), + ("b2", &vec![2, 4, 6, 6, 8]), + ("c2", &vec![50, 60, 70, 80, 90]), + ); + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b2", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Left).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 0 | 3 | 4 | | | |", + "| 1 | 4 | 5 | 10 | 4 | 60 |", + "| 2 | 5 | 6 | | | |", + "| 3 | 6 | 7 | 20 | 6 | 70 |", + "| 3 | 6 | 7 | 30 | 6 | 80 |", + "| 4 | 6 | 8 | 20 | 6 | 70 |", + "| 4 | 6 | 8 | 30 | 6 | 80 |", + "| 5 | 7 | 9 | | | |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_right_sort_order() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table( + ("a1", &vec![0, 1, 2, 3]), + ("b1", &vec![3, 4, 5, 7]), + ("c1", &vec![6, 7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![0, 10, 20, 30]), + ("b2", &vec![2, 4, 5, 6]), + ("c2", &vec![60, 70, 80, 90]), + ); + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b2", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Right).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| | | | 0 | 2 | 60 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| | | | 30 | 6 | 90 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_left_multiple_batches() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left_batch_1 = build_table_i32( + ("a1", &vec![0, 1, 2]), + ("b1", &vec![3, 4, 5]), + ("c1", &vec![4, 5, 6]), + ); + let left_batch_2 = build_table_i32( + ("a1", &vec![3, 4, 5, 6]), + ("b1", 
&vec![6, 6, 7, 9]), + ("c1", &vec![7, 8, 9, 9]), + ); + let right_batch_1 = build_table_i32( + ("a2", &vec![0, 10, 20]), + ("b2", &vec![2, 4, 6]), + ("c2", &vec![50, 60, 70]), + ); + let right_batch_2 = build_table_i32( + ("a2", &vec![30, 40]), + ("b2", &vec![6, 8]), + ("c2", &vec![80, 90]), + ); + let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); + let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b2", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Left).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 0 | 3 | 4 | | | |", + "| 1 | 4 | 5 | 10 | 4 | 60 |", + "| 2 | 5 | 6 | | | |", + "| 3 | 6 | 7 | 20 | 6 | 70 |", + "| 3 | 6 | 7 | 30 | 6 | 80 |", + "| 4 | 6 | 8 | 20 | 6 | 70 |", + "| 4 | 6 | 8 | 30 | 6 | 80 |", + "| 5 | 7 | 9 | | | |", + "| 6 | 9 | 9 | | | |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_right_multiple_batches() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let right_batch_1 = build_table_i32( + ("a2", &vec![0, 1, 2]), + ("b2", &vec![3, 4, 5]), + ("c2", &vec![4, 5, 6]), + ); + let right_batch_2 = build_table_i32( + ("a2", &vec![3, 4, 5, 6]), + ("b2", &vec![6, 6, 7, 9]), + ("c2", &vec![7, 8, 9, 9]), + ); + let left_batch_1 = build_table_i32( + ("a1", &vec![0, 10, 20]), + ("b1", &vec![2, 4, 6]), + ("c1", &vec![50, 60, 70]), + ); + let left_batch_2 = build_table_i32( + ("a1", &vec![30, 40]), + ("b1", &vec![6, 8]), + ("c1", &vec![80, 90]), + ); + let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); + let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b2", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Right).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| | | | 0 | 3 | 4 |", + "| 10 | 4 | 60 | 1 | 4 | 5 |", + "| | | | 2 | 5 | 6 |", + "| 20 | 6 | 70 | 3 | 6 | 7 |", + "| 30 | 6 | 80 | 3 | 6 | 7 |", + "| 20 | 6 | 70 | 4 | 6 | 8 |", + "| 30 | 6 | 80 | 4 | 6 | 8 |", + "| | | | 5 | 7 | 9 |", + "| | | | 6 | 9 | 9 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_full_multiple_batches() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left_batch_1 = build_table_i32( + ("a1", &vec![0, 1, 2]), + ("b1", &vec![3, 4, 5]), + ("c1", &vec![4, 5, 6]), + ); + let left_batch_2 = build_table_i32( + ("a1", &vec![3, 4, 5, 6]), + ("b1", &vec![6, 6, 7, 9]), + ("c1", &vec![7, 8, 9, 9]), + ); + let right_batch_1 = build_table_i32( + ("a2", &vec![0, 10, 20]), + ("b2", &vec![2, 4, 6]), + ("c2", &vec![50, 60, 70]), + ); + let right_batch_2 = build_table_i32( + ("a2", &vec![30, 40]), + ("b2", &vec![6, 8]), + ("c2", &vec![80, 90]), + ); + let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); + let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", 
&left.schema())?), + Arc::new(Column::new_with_schema("b2", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Full).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| | | | 0 | 2 | 50 |", + "| | | | 40 | 8 | 90 |", + "| 0 | 3 | 4 | | | |", + "| 1 | 4 | 5 | 10 | 4 | 60 |", + "| 2 | 5 | 6 | | | |", + "| 3 | 6 | 7 | 20 | 6 | 70 |", + "| 3 | 6 | 7 | 30 | 6 | 80 |", + "| 4 | 6 | 8 | 20 | 6 | 70 |", + "| 4 | 6 | 8 | 30 | 6 | 80 |", + "| 5 | 7 | 9 | | | |", + "| 6 | 9 | 9 | | | |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_existence_multiple_batches() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left_batch_1 = build_table_i32( + ("a1", &vec![0, 1, 2]), + ("b1", &vec![3, 4, 5]), + ("c1", &vec![4, 5, 6]), + ); + let left_batch_2 = build_table_i32( + ("a1", &vec![3, 4, 5, 6]), + ("b1", &vec![6, 6, 7, 9]), + ("c1", &vec![7, 8, 9, 9]), + ); + let right_batch_1 = build_table_i32( + ("a2", &vec![0, 10, 20]), + ("b2", &vec![2, 4, 6]), + ("c2", &vec![50, 60, 70]), + ); + let right_batch_2 = build_table_i32( + ("a2", &vec![30, 40]), + ("b2", &vec![6, 8]), + ("c2", &vec![80, 90]), + ); + let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); + let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b2", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Existence).await?; + let expected = vec![ + "+----+----+----+----------+", + "| a1 | b1 | c1 | exists#0 |", + "+----+----+----+----------+", + "| 0 | 3 | 4 | false |", + "| 1 | 4 | 5 | true |", + "| 2 | 5 | 6 | false |", + "| 3 | 6 | 7 | true |", + "| 4 | 6 | 8 | true |", + "| 5 | 7 | 9 | false |", + "| 6 | 9 | 9 | false |", + "+----+----+----+----------+", + ]; + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } +} diff --git a/native-engine/datafusion-ext-plans/src/lib.rs b/native-engine/datafusion-ext-plans/src/lib.rs index a0797fb09..a48fb56a2 100644 --- a/native-engine/datafusion-ext-plans/src/lib.rs +++ b/native-engine/datafusion-ext-plans/src/lib.rs @@ -13,32 +13,38 @@ // limitations under the License. 
#![feature(get_mut_unchecked)] -#![feature(io_error_other)] +#![feature(adt_const_params)] -pub mod agg; +// execution plan implementations pub mod agg_exec; +pub mod broadcast_join_build_hash_map_exec; pub mod broadcast_join_exec; -pub mod broadcast_nested_loop_join_exec; -pub mod common; pub mod debug_exec; pub mod empty_partitions_exec; pub mod expand_exec; pub mod ffi_reader_exec; pub mod filter_exec; -pub mod generate; pub mod generate_exec; pub mod ipc_reader_exec; pub mod ipc_writer_exec; pub mod limit_exec; -pub mod memmgr; pub mod parquet_exec; pub mod parquet_sink_exec; pub mod project_exec; pub mod rename_columns_exec; pub mod rss_shuffle_writer_exec; -mod shuffle; pub mod shuffle_writer_exec; pub mod sort_exec; pub mod sort_merge_join_exec; -pub mod window; pub mod window_exec; + +// memory management +pub mod memmgr; + +// helper modules +pub mod agg; +pub mod common; +pub mod generate; +pub mod joins; +mod shuffle; +pub mod window; diff --git a/native-engine/datafusion-ext-plans/src/parquet_exec.rs b/native-engine/datafusion-ext-plans/src/parquet_exec.rs index 8fd5f57f6..f7c206273 100644 --- a/native-engine/datafusion-ext-plans/src/parquet_exec.rs +++ b/native-engine/datafusion-ext-plans/src/parquet_exec.rs @@ -20,7 +20,7 @@ use std::{any::Any, fmt, fmt::Formatter, ops::Range, sync::Arc}; use arrow::{ - array::ArrayRef, + array::{Array, ArrayRef, AsArray, ListArray}, datatypes::{DataType, SchemaRef}, }; use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine}; @@ -56,7 +56,6 @@ use datafusion::{ use datafusion_ext_commons::{ batch_size, df_execution_err, hadoop_fs::{FsDataInputStream, FsProvider}, - streams::coalesce_stream::CoalesceInput, }; use fmt::Debug; use futures::{future::BoxFuture, stream::once, FutureExt, StreamExt, TryStreamExt}; @@ -71,7 +70,61 @@ fn schema_adapter_cast_column( col: &ArrayRef, data_type: &DataType, ) -> Result { - datafusion_ext_commons::cast::cast_scan_input_array(col.as_ref(), data_type) + macro_rules! handle_decimal { + ($s:ident, $t:ident, $tnative:ty, $prec:expr, $scale:expr) => {{ + use arrow::{array::*, datatypes::*}; + type DecimalBuilder = paste::paste! {[<$t Builder>]}; + type IntType = paste::paste! 
{[<$s Type>]}; + + let col = col.as_primitive::(); + let mut decimal_builder = DecimalBuilder::new(); + for i in 0..col.len() { + if col.is_valid(i) { + decimal_builder.append_value(col.value(i) as $tnative); + } else { + decimal_builder.append_null(); + } + } + Ok(Arc::new( + decimal_builder + .finish() + .with_precision_and_scale($prec, $scale)?, + )) + }}; + } + match data_type { + DataType::Decimal128(prec, scale) => match col.data_type() { + DataType::Int8 => handle_decimal!(Int8, Decimal128, i128, *prec, *scale), + DataType::Int16 => handle_decimal!(Int16, Decimal128, i128, *prec, *scale), + DataType::Int32 => handle_decimal!(Int32, Decimal128, i128, *prec, *scale), + DataType::Int64 => handle_decimal!(Int64, Decimal128, i128, *prec, *scale), + DataType::Decimal128(p, s) if p == prec && s == scale => Ok(col.clone()), + _ => df_execution_err!( + "schema_adapter_cast_column unsupported type: {:?} => {:?}", + col.data_type(), + data_type, + ), + }, + DataType::List(to_field) => match col.data_type() { + DataType::List(_from_field) => { + let col = col.as_list::(); + let from_inner = col.values(); + let to_inner = schema_adapter_cast_column(from_inner, to_field.data_type())?; + Ok(Arc::new(ListArray::try_new( + to_field.clone(), + col.offsets().clone(), + to_inner, + col.nulls().cloned(), + )?)) + } + _ => df_execution_err!( + "schema_adapter_cast_column unsupported type: {:?} => {:?}", + col.data_type(), + data_type, + ), + }, + _ => datafusion_ext_commons::cast::cast_scan_input_array(col.as_ref(), data_type), + } } /// Execution plan for scanning one or more Parquet partitions @@ -231,6 +284,9 @@ impl ExecutionPlan for ParquetExec { None => (0..self.base_config.file_schema.fields().len()).collect(), }; + let page_filtering_enabled = conf::PARQUET_ENABLE_PAGE_FILTERING.value()?; + let bloom_filter_enabled = conf::PARQUET_ENABLE_BLOOM_FILTER.value()?; + let opener = ParquetOpener { partition_index, projection: Arc::from(projection), @@ -243,10 +299,10 @@ impl ExecutionPlan for ParquetExec { metadata_size_hint: None, metrics: self.metrics.clone(), parquet_file_reader_factory: Arc::new(FsReaderFactory::new(fs_provider)), - pushdown_filters: false, - reorder_filters: false, - enable_page_index: false, - enable_bloom_filter: false, + pushdown_filters: page_filtering_enabled, + reorder_filters: page_filtering_enabled, + enable_page_index: page_filtering_enabled, + enable_bloom_filter: bloom_filter_enabled, }; let baseline_metrics_cloned = baseline_metrics.clone(); @@ -274,7 +330,7 @@ impl ExecutionPlan for ParquetExec { }) .try_flatten(), )); - context.coalesce_with_default_batch_size(timed_stream, &baseline_metrics) + Ok(timed_stream) } fn metrics(&self) -> Option { diff --git a/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs b/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs index f2dff1dbb..69b46cf77 100644 --- a/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs +++ b/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs @@ -35,7 +35,6 @@ use datafusion::{ SendableRecordBatchStream, Statistics, }, }; -use datafusion_ext_commons::df_execution_err; use futures::{Stream, StreamExt}; use crate::agg::AGG_BUF_COLUMN_NAME; @@ -56,7 +55,12 @@ impl RenameColumnsExec { let input_schema = input.schema(); let mut new_names = vec![]; - for (i, field) in input_schema.fields().iter().enumerate() { + for (i, field) in input_schema + .fields() + .iter() + .take(renamed_column_names.len()) + .enumerate() + { if field.name() != AGG_BUF_COLUMN_NAME { 
new_names.push(renamed_column_names[i].clone()); } else { @@ -64,11 +68,9 @@ impl RenameColumnsExec { break; } } - if new_names.len() != input_schema.fields().len() { - df_execution_err!( - "renamed_column_names length not matched with input schema, \ - renames: {renamed_column_names:?}, input schema: {input_schema}", - )?; + + while new_names.len() < input_schema.fields().len() { + new_names.push(input_schema.field(new_names.len()).name().clone()); } let renamed_column_names = new_names; let renamed_schema = Arc::new(Schema::new( diff --git a/native-engine/datafusion-ext-plans/src/sort_exec.rs b/native-engine/datafusion-ext-plans/src/sort_exec.rs index 33a2818e8..56a8ec6e8 100644 --- a/native-engine/datafusion-ext-plans/src/sort_exec.rs +++ b/native-engine/datafusion-ext-plans/src/sort_exec.rs @@ -49,7 +49,6 @@ use datafusion_ext_commons::{ downcast_any, ds::loser_tree::{ComparableForLoserTree, LoserTree}, io::{read_len, read_one_batch, write_len, write_one_batch}, - staging_mem_size_for_partial_sort, streams::coalesce_stream::CoalesceInput, }; use futures::{lock::Mutex, stream::once, StreamExt, TryStreamExt}; @@ -59,7 +58,7 @@ use parking_lot::Mutex as SyncMutex; use crate::{ common::{ - batch_selection::interleave_batches, + batch_selection::{interleave_batches, take_batch}, batch_statisitcs::{stat_input, InputBatchStatistics}, column_pruning::ExecuteWithColumnPruning, output::{TaskOutputter, WrappedRecordBatchSender}, @@ -242,11 +241,9 @@ impl MemConsumer for ExternalSorter { #[derive(Default)] struct BufferedData { - staging_batches: Vec, sorted_key_stores: Vec>, sorted_key_stores_mem_used: usize, sorted_batches: Vec, - staging_mem_used: usize, sorted_batches_mem_used: usize, num_rows: usize, } @@ -271,34 +268,15 @@ impl BufferedData { } fn mem_used(&self) -> usize { - self.staging_mem_used + self.sorted_batches_mem_used + self.sorted_key_stores_mem_used + self.sorted_batches_mem_used + self.sorted_key_stores_mem_used } fn add_batch(&mut self, batch: RecordBatch, sorter: &ExternalSorter) -> Result<()> { self.num_rows += batch.num_rows(); - self.staging_mem_used += batch.get_array_mem_size(); - self.staging_batches.push(batch); - if self.staging_mem_used >= staging_mem_size_for_partial_sort() { - self.flush_staging_batches(sorter)?; - } - Ok(()) - } - - fn flush_staging_batches(&mut self, sorter: &ExternalSorter) -> Result<()> { - let staging_batches = std::mem::take(&mut self.staging_batches); - self.staging_mem_used = 0; - - let schema = sorter.prune_sort_keys_from_batch.pruned_schema.clone(); - let (key_rows, batches): (Vec, Vec) = staging_batches - .into_iter() - .map(|batch| sorter.prune_sort_keys_from_batch.prune(batch)) - .collect::>>()? 
- .into_iter() - .unzip(); + let (key_rows, batch) = sorter.prune_sort_keys_from_batch.prune(batch)?; // sort the batch and append to sorter - let mut sorted_key_store = - Vec::with_capacity(key_rows.iter().map(|rows| rows.size()).sum::()); + let mut sorted_key_store = Vec::with_capacity(key_rows.size()); let mut key_writer = SortedKeysWriter::default(); let mut num_rows = 0; let sorted_batch; @@ -307,32 +285,28 @@ impl BufferedData { let cur_sorted_indices = key_rows .iter() .enumerate() - .flat_map(|(batch_idx, rows)| { - rows.iter() - .map(|key| unsafe { - // safety: keys have the same lifetime with key_rows - std::mem::transmute::<_, &'static [u8]>(key.as_ref()) - }) - .enumerate() - .map(move |(row_idx, key)| (key, batch_idx as u32, row_idx as u32)) + .map(|(row_idx, key)| { + let key = unsafe { + // safety: keys have the same lifetime with key_rows + std::mem::transmute::<_, &'static [u8]>(key.as_ref()) + }; + (key, row_idx as u32) }) .sorted_unstable_by_key(|&(key, ..)| key) .take(sorter.limit) - .map(|(key, batch_idx, row_idx)| { + .map(|(key, row_idx)| { num_rows += 1; key_writer.write_key(key, &mut sorted_key_store).unwrap(); - (batch_idx as usize, row_idx as usize) + row_idx as usize }) .collect::>(); - sorted_batch = interleave_batches(schema, &batches, &cur_sorted_indices)?; + sorted_batch = take_batch(batch, cur_sorted_indices)?; } else { key_rows .iter() - .flat_map(|rows| { - rows.iter().map(|key| unsafe { - // safety: keys have the same lifetime with key_rows - std::mem::transmute::<_, &'static [u8]>(key.as_ref()) - }) + .map(|key| unsafe { + // safety: keys have the same lifetime with key_rows + std::mem::transmute::<_, &'static [u8]>(key.as_ref()) }) .sorted_unstable() .take(sorter.limit) @@ -351,13 +325,10 @@ impl BufferedData { } fn into_sorted_batches<'a, KC: KeyCollector>( - mut self, + self, batch_size: usize, sorter: &ExternalSorter, ) -> Result> { - if !self.staging_batches.is_empty() { - self.flush_staging_batches(sorter)?; - } struct Cursor { idx: usize, row_idx: usize, diff --git a/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs b/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs index 8459d476d..d4e5cda37 100644 --- a/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs +++ b/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs @@ -12,135 +12,151 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use std::{any::Any, cmp::Ordering, fmt::Formatter, sync::Arc}; - -use arrow::{ - array::*, - buffer::NullBuffer, - compute::{prep_null_mask_filter, SortOptions}, - datatypes::{DataType, Schema, SchemaRef}, - record_batch::{RecordBatch, RecordBatchOptions}, - row::{Row, RowConverter, Rows, SortField}, +use std::{ + any::Any, + fmt::Formatter, + pin::Pin, + sync::Arc, + time::{Duration, Instant}, }; + +use arrow::{compute::SortOptions, datatypes::SchemaRef}; +use async_trait::async_trait; use datafusion::{ - common::JoinSide, + common::{DataFusionError, JoinSide}, error::Result, execution::context::TaskContext, - logical_expr::{JoinType, JoinType::*}, - physical_expr::{expressions::Column, PhysicalSortExpr}, + physical_expr::{PhysicalExprRef, PhysicalSortExpr}, physical_plan::{ - joins::utils::{build_join_schema, check_join_is_valid, ColumnIndex, JoinFilter, JoinOn}, - metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, ScopedTimerGuard}, + joins::utils::JoinOn, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }, }; use datafusion_ext_commons::{ - array_size::ArraySize, batch_size, df_execution_err, downcast_any, - streams::coalesce_stream::CoalesceInput, suggested_output_batch_mem_size, + batch_size, df_execution_err, streams::coalesce_stream::CoalesceInput, }; -use futures::{StreamExt, TryStreamExt}; -use parking_lot::Mutex as SyncMutex; +use futures::TryStreamExt; -use crate::common::{ - batch_selection::{interleave_batches, take_batch_opt}, - column_pruning::ExecuteWithColumnPruning, - output::{TaskOutputter, WrappedRecordBatchSender}, +use crate::{ + common::{ + column_pruning::ExecuteWithColumnPruning, + output::{TaskOutputter, WrappedRecordBatchSender}, + }, + cur_forward, + joins::{ + join_utils::{JoinType, JoinType::*}, + smj::{ + existence_join::ExistenceJoiner, + full_join::{FullOuterJoiner, InnerJoiner, LeftOuterJoiner, RightOuterJoiner}, + semi_join::{LeftAntiJoiner, LeftSemiJoiner, RightAntiJoiner, RightSemiJoiner}, + }, + stream_cursor::StreamCursor, + JoinParams, JoinProjection, StreamCursors, + }, }; #[derive(Debug)] pub struct SortMergeJoinExec { - /// Left sorted joining execution plan left: Arc, - /// Right sorting joining execution plan right: Arc, - /// Set of common columns used to join on on: JoinOn, - /// How the join is performed join_type: JoinType, - /// Optional filter before outputting - join_filter: Option, - /// The schema once the join is applied + sort_options: Vec, schema: SchemaRef, - /// Execution metrics metrics: ExecutionPlanMetricsSet, - /// Sort options of join columns used in sorting left and right execution - /// plans - sort_options: Vec, } impl SortMergeJoinExec { pub fn try_new( + schema: SchemaRef, left: Arc, right: Arc, on: JoinOn, join_type: JoinType, - join_filter: Option, sort_options: Vec, ) -> Result { - let left_schema = left.schema(); - let right_schema = right.schema(); - - if matches!(join_type, LeftSemi | LeftAnti | RightSemi | RightAnti,) { - if join_filter.is_some() { - df_execution_err!("Semi/Anti join with filter is not supported yet")?; - } - } - - check_join_is_valid(&left_schema, &right_schema, &on)?; - if sort_options.len() != on.len() { - df_execution_err!( - "Expected number of sort options: {}, actual: {}", - on.len(), - sort_options.len(), - )?; - } - - let schema = Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0); Ok(Self { + schema, left, 
right, on, join_type, - join_filter, - schema, - metrics: ExecutionPlanMetricsSet::new(), sort_options, + metrics: ExecutionPlanMetricsSet::new(), }) } - fn create_join_params(&self, batch_size: usize) -> JoinParams { - let on_left: Vec = self + fn create_join_params(&self, projection: &[usize]) -> Result { + let left_schema = self.left.schema(); + let right_schema = self.right.schema(); + let (left_keys, right_keys): (Vec, Vec) = + self.on.iter().cloned().unzip(); + let key_data_types = self .on .iter() - .map(|on| downcast_any!(on.0, Column).unwrap().index()) - .collect(); - let on_right: Vec = self - .on - .iter() - .map(|on| downcast_any!(on.1, Column).unwrap().index()) - .collect(); - let on_data_types = on_left - .iter() - .map(|&i| self.left.schema().field(i).data_type().clone()) - .collect::>(); - let sub_batch_size = batch_size / batch_size.ilog10() as usize; + .map(|(left_key, right_key)| { + Ok({ + let left_dt = left_key.data_type(&left_schema)?; + let right_dt = right_key.data_type(&right_schema)?; + if left_dt != right_dt { + df_execution_err!( + "join key data type differs {left_dt:?} <-> {right_dt:?}" + )?; + } + left_dt + }) + }) + .collect::>()?; - // use smaller batch size and coalesce batches at the end, to avoid buffer - // overflowing - JoinParams { + let projection = JoinProjection::try_new( + self.join_type, + &self.schema, + &left_schema, + &right_schema, + projection, + )?; + Ok(JoinParams { join_type: self.join_type, + left_schema, + right_schema, output_schema: self.schema(), - on_left, - on_right, - on_data_types, - join_filter: self.join_filter.clone(), + left_keys, + right_keys, + key_data_types, sort_options: self.sort_options.clone(), - batch_size: sub_batch_size, - left_output_projection: (0..self.left.schema().fields().len()).collect(), - right_output_projection: (0..self.right.schema().fields().len()).collect(), - } + projection, + batch_size: batch_size(), + }) + } + + fn execute_with_projection( + &self, + partition: usize, + context: Arc, + projection: Vec, + ) -> Result { + let metrics = Arc::new(BaselineMetrics::new(&self.metrics, partition)); + let join_params = self.create_join_params(&projection)?; + let left = self.left.execute(partition, context.clone())?; + let right = self.right.execute(partition, context.clone())?; + + let metrics_cloned = metrics.clone(); + let context_cloned = context.clone(); + let output_stream = Box::pin(RecordBatchStreamAdapter::new( + join_params.projection.schema.clone(), + futures::stream::once(async move { + context_cloned.output_with_sender( + "SortMergeJoin", + join_params.projection.schema.clone(), + move |sender| execute_join(left, right, join_params, metrics_cloned, sender), + ) + }) + .try_flatten(), + )); + Ok(context.coalesce_with_default_batch_size(output_stream, &metrics)?) 
} } @@ -154,6 +170,17 @@ impl DisplayAs for SortMergeJoinExec { } } +impl ExecuteWithColumnPruning for SortMergeJoinExec { + fn execute_projected( + &self, + partition: usize, + context: Arc, + projection: &[usize], + ) -> Result { + self.execute_with_projection(partition, context, projection.to_vec()) + } +} + impl ExecutionPlan for SortMergeJoinExec { fn as_any(&self) -> &dyn Any { self @@ -169,7 +196,7 @@ impl ExecutionPlan for SortMergeJoinExec { fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { match self.join_type { - Left | LeftSemi | LeftAnti => self.left.output_ordering(), + Left | LeftSemi | LeftAnti | Existence => self.left.output_ordering(), Right | RightSemi | RightAnti => self.right.output_ordering(), Inner => self.left.output_ordering(), Full => None, @@ -185,11 +212,11 @@ impl ExecutionPlan for SortMergeJoinExec { children: Vec>, ) -> Result> { Ok(Arc::new(SortMergeJoinExec::try_new( + self.schema(), children[0].clone(), children[1].clone(), self.on.clone(), self.join_type, - self.join_filter.clone(), self.sort_options.clone(), )?)) } @@ -199,12 +226,8 @@ impl ExecutionPlan for SortMergeJoinExec { partition: usize, context: Arc, ) -> Result { - let metrics = Arc::new(BaselineMetrics::new(&self.metrics, partition)); - let batch_size = batch_size(); - let join_params = self.create_join_params(batch_size); - let left = self.left.execute(partition, context.clone())?; - let right = self.right.execute(partition, context.clone())?; - execute_with_join_params(context, join_params, left, right, metrics) + let projection = (0..self.schema.fields().len()).collect(); + self.execute_with_projection(partition, context, projection) } fn metrics(&self) -> Option { @@ -216,1549 +239,76 @@ impl ExecutionPlan for SortMergeJoinExec { } } -impl ExecuteWithColumnPruning for SortMergeJoinExec { - fn execute_projected( - &self, - partition: usize, - context: Arc, - projection: &[usize], - ) -> Result { - let metrics = Arc::new(BaselineMetrics::new(&self.metrics, partition)); - let batch_size = batch_size(); - - let (join_params, left_projection, right_projection) = - self.create_join_params(batch_size).project(projection)?; - let left = self - .left - .execute_projected(partition, context.clone(), &left_projection)?; - let right = self - .right - .execute_projected(partition, context.clone(), &right_projection)?; - execute_with_join_params(context, join_params, left, right, metrics) - } -} - -#[derive(Clone)] -struct JoinParams { - join_type: JoinType, - output_schema: SchemaRef, - on_left: Vec, - on_right: Vec, - on_data_types: Vec, - sort_options: Vec, - join_filter: Option, - left_output_projection: Vec, - right_output_projection: Vec, - batch_size: usize, -} - -impl JoinParams { - fn project(&self, projection: &[usize]) -> Result<(Self, Vec, Vec)> { - let num_left_fields = self.left_output_projection.len(); - let mut left_projection = vec![]; - let mut right_projection = vec![]; - - for &i in projection { - match self.join_type { - Inner | Left | Right | Full => { - if i < num_left_fields { - left_projection.push(i); - } else { - right_projection.push(i - num_left_fields); - } - } - LeftSemi | LeftAnti => { - left_projection.push(i); - } - RightSemi | RightAnti => { - right_projection.push(i); - } - } - } - let num_left_output_columns = left_projection.len(); - let num_right_output_columns = right_projection.len(); - - let mut on_left_projected = vec![]; - let mut on_right_projected = vec![]; - for &l in &self.on_left { - on_left_projected.push(left_projection.iter().position(|&i| i 
== l).unwrap_or_else( - || { - left_projection.push(l); - left_projection.len() - 1 - }, - )); - } - for &r in &self.on_right { - on_right_projected.push( - right_projection - .iter() - .position(|&i| i == r) - .unwrap_or_else(|| { - right_projection.push(r); - right_projection.len() - 1 - }), - ); - } - - let mut join_filter_projected = None; - if let Some(join_filter) = &self.join_filter { - join_filter_projected = Some(JoinFilter::new( - join_filter.expression().clone(), - join_filter - .column_indices() - .iter() - .map(|ci| { - let projected_index = match ci.side { - JoinSide::Left => left_projection - .iter() - .position(|&i| i == ci.index) - .unwrap_or_else(|| { - left_projection.push(ci.index); - left_projection.len() - 1 - }), - JoinSide::Right => right_projection - .iter() - .position(|&i| i == ci.index) - .unwrap_or_else(|| { - right_projection.push(ci.index); - right_projection.len() - 1 - }), - }; - ColumnIndex { - index: projected_index, - side: ci.side, - } - }) - .collect(), - join_filter.schema().clone(), - )); - } - - let projected = Self { - join_type: self.join_type, - output_schema: Arc::new(self.output_schema.project(projection)?), - on_left: on_left_projected, - on_right: on_right_projected, - on_data_types: self.on_data_types.clone(), - sort_options: self.sort_options.clone(), - join_filter: join_filter_projected, - batch_size: self.batch_size, - left_output_projection: (0..num_left_output_columns).collect(), - right_output_projection: (0..num_right_output_columns).collect(), - }; - Ok((projected, left_projection, right_projection)) - } -} - -fn execute_with_join_params( - context: Arc, - join_params: JoinParams, - left: SendableRecordBatchStream, - right: SendableRecordBatchStream, - metrics: Arc, -) -> Result { - let metrics_cloned = metrics.clone(); - let context_cloned = context.clone(); - let output_schema = join_params.output_schema.clone(); - let output_stream = Box::pin(RecordBatchStreamAdapter::new( - join_params.output_schema.clone(), - futures::stream::once(async move { - context_cloned.output_with_sender("SortMergeJoin", output_schema, move |sender| { - execute_join(left, right, join_params, metrics_cloned, sender) - }) - }) - .try_flatten(), - )); - Ok(context.coalesce_with_default_batch_size(output_stream, &metrics)?) 
-} - -async fn execute_join( +pub async fn execute_join( lstream: SendableRecordBatchStream, rstream: SendableRecordBatchStream, join_params: JoinParams, metrics: Arc, sender: Arc, ) -> Result<()> { - let elapsed_time = metrics.elapsed_compute().clone(); - let mut timer = elapsed_time.timer(); - - let on_row_converter = Arc::new(SyncMutex::new(RowConverter::new( - join_params - .on_data_types - .iter() - .zip(&join_params.sort_options) - .map(|(data_type, sort_option)| { - SortField::new_with_options(data_type.clone(), *sort_option) - }) - .collect(), - )?)); - - let mut lcur = StreamCursor::try_new( - lstream, - on_row_converter.clone(), - join_params.on_left.clone(), - join_params.left_output_projection.clone(), - )?; - let mut rcur = StreamCursor::try_new( - rstream, - on_row_converter.clone(), - join_params.on_right.clone(), - join_params.right_output_projection.clone(), + let start_time = Instant::now(); + + let mut curs = ( + StreamCursor::try_new( + lstream, + &join_params, + JoinSide::Left, + &join_params.projection.left, + )?, + StreamCursor::try_new( + rstream, + &join_params, + JoinSide::Right, + &join_params.projection.right, + )?, + ); + + // start first batches of both side asynchronously + tokio::try_join!( + async { Ok::<_, DataFusionError>(cur_forward!(curs.0)) }, + async { Ok::<_, DataFusionError>(cur_forward!(curs.1)) }, )?; - macro_rules! forward { - ($cur:expr) => {{ - if $cur.next() == NextAction::LoadNextBatch { - $cur.next_batch(&mut timer).await?; - } - }}; - } - - // load first record - forward!(lcur); - forward!(rcur); - let join_type = join_params.join_type; - let mut joiner = Joiner::new(); - let mut leqs = vec![]; - let mut reqs = vec![]; - - macro_rules! joiner_accept_pair { - ($lidx:expr, $ridx:expr) => {{ - let lidx = $lidx; - let ridx = $ridx; - let r = joiner.accept_pair(&join_params, &mut lcur, &mut rcur, lidx, ridx)?; - if let Some(batch) = r { - metrics.record_output(batch.num_rows()); - sender.send(Ok(batch), Some(&mut timer)).await; - } - }}; - } - - // process records until one side is exhausted - while !lcur.finished && !rcur.finished { - let r = compare_cursor(&lcur, lcur.cur_idx, &rcur, rcur.cur_idx); - match r { - Ordering::Less => { - if matches!(join_type, Left | LeftAnti | Full) { - joiner_accept_pair!(Some(lcur.cur_idx), None); - } - forward!(lcur); - lcur.clear_outdated(joiner.l_min_reserved_bidx); - } - Ordering::Greater => { - if matches!(join_type, Right | RightAnti | Full) { - joiner_accept_pair!(None, Some(rcur.cur_idx)); - } - forward!(rcur); - rcur.clear_outdated(joiner.r_min_reserved_bidx); - } - Ordering::Equal => { - let lidx0 = lcur.cur_idx; - let ridx0 = rcur.cur_idx; - leqs.push(lidx0); - reqs.push(ridx0); - forward!(lcur); - forward!(rcur); - - let mut leq = true; - let mut req = true; - while leq && req { - if leq && !lcur.finished && lcur.row(lcur.cur_idx) == lcur.row(lidx0) { - leqs.push(lcur.cur_idx); - forward!(lcur); - } else { - leq = false; - } - if req && !rcur.finished && rcur.row(rcur.cur_idx) == rcur.row(ridx0) { - reqs.push(rcur.cur_idx); - forward!(rcur); - } else { - req = false; - } - } - - match join_type { - Inner | Left | Right | Full => { - for &l in &leqs { - for &r in &reqs { - joiner_accept_pair!(Some(l), Some(r)); - } - } - } - LeftSemi => { - for &l in &leqs { - joiner_accept_pair!(Some(l), None); - } - } - RightSemi => { - for &r in &reqs { - joiner_accept_pair!(None, Some(r)); - } - } - LeftAnti | RightAnti => {} - } - - if leq { - while !lcur.finished && lcur.row(lcur.cur_idx) == rcur.row(ridx0) { - 
match join_type { - Inner | Left | Right | Full => { - for &r in &reqs { - joiner_accept_pair!(Some(lcur.cur_idx), Some(r)); - } - } - LeftSemi => { - joiner_accept_pair!(Some(lcur.cur_idx), None); - } - RightSemi | LeftAnti | RightAnti => {} - } - forward!(lcur); - lcur.clear_outdated(joiner.l_min_reserved_bidx); - } - } - if req { - while !rcur.finished && rcur.row(rcur.cur_idx) == lcur.row(lidx0) { - match join_type { - Inner | Left | Right | Full => { - for &l in &leqs { - joiner_accept_pair!(Some(l), Some(rcur.cur_idx)); - } - } - RightSemi => { - joiner_accept_pair!(None, Some(rcur.cur_idx)); - } - LeftSemi | LeftAnti | RightAnti => {} - } - forward!(rcur); - rcur.clear_outdated(joiner.r_min_reserved_bidx); - } - } - leqs.clear(); - reqs.clear(); - lcur.clear_outdated(joiner.l_min_reserved_bidx); - rcur.clear_outdated(joiner.r_min_reserved_bidx); - } - } - - // flush joiner if cursors buffered too many batches - if !joiner.is_empty() && (lcur.num_buffered_batches() + rcur.num_buffered_batches() > 5) - || (lcur.mem_size() + rcur.mem_size() > suggested_output_batch_mem_size() - && lcur.num_buffered_batches() > 1 - && rcur.num_buffered_batches() > 1) - { - if let Some(batch) = joiner.flush_pairs(&join_params, &mut lcur, &mut rcur)? { - metrics.record_output(batch.num_rows()); - sender.send(Ok(batch), Some(&mut timer)).await; - } - } - } - - // process rest records in inexhausted side - if matches!(join_type, Left | LeftAnti | Full) { - while !lcur.finished { - joiner_accept_pair!(Some(lcur.cur_idx), None); - forward!(lcur); - lcur.clear_outdated(joiner.l_min_reserved_bidx); - } - } - if matches!(join_type, Right | RightAnti | Full) { - while !rcur.finished { - joiner_accept_pair!(None, Some(rcur.cur_idx)); - forward!(rcur); - rcur.clear_outdated(joiner.r_min_reserved_bidx); - } - } - - // flush joiner - if !joiner.is_empty() { - if let Some(batch) = joiner.flush_pairs(&join_params, &mut lcur, &mut rcur)? 
{ - metrics.record_output(batch.num_rows()); - sender.send(Ok(batch), Some(&mut timer)).await; - } - } + let mut joiner: Pin> = match join_type { + Inner => Box::pin(InnerJoiner::new(join_params, sender)), + Left => Box::pin(LeftOuterJoiner::new(join_params, sender)), + Right => Box::pin(RightOuterJoiner::new(join_params, sender)), + Full => Box::pin(FullOuterJoiner::new(join_params, sender)), + LeftSemi => Box::pin(LeftSemiJoiner::new(join_params, sender)), + RightSemi => Box::pin(RightSemiJoiner::new(join_params, sender)), + LeftAnti => Box::pin(LeftAntiJoiner::new(join_params, sender)), + RightAnti => Box::pin(RightAntiJoiner::new(join_params, sender)), + Existence => Box::pin(ExistenceJoiner::new(join_params, sender)), + }; + joiner.as_mut().join(&mut curs).await?; + metrics.record_output(joiner.num_output_rows()); + + // discount poll input and send output batch time + let mut join_time_ns = (Instant::now() - start_time).as_nanos() as u64; + join_time_ns -= joiner.total_send_output_time() as u64; + join_time_ns -= curs.0.total_poll_time() as u64; + join_time_ns -= curs.1.total_poll_time() as u64; + metrics + .elapsed_compute() + .add_duration(Duration::from_nanos(join_time_ns)); Ok(()) } -struct StreamCursor { - stream: SendableRecordBatchStream, - on_row_converter: Arc>, - on_columns: Vec, - - // IMPORTANT: - // batches/rows/null_buffers always contains a `null batch` in the front - batches: Vec, - projected_batches: Vec, - projection: Vec, - on_rows: Vec>, - on_row_null_buffers: Vec>, - cur_idx: (usize, usize), - num_null_batches: usize, - mem_size: usize, - finished: bool, -} - -#[derive(Clone, Copy, PartialEq, Eq)] -enum NextAction { - None, - LoadNextBatch, -} - -impl StreamCursor { - fn try_new( - stream: SendableRecordBatchStream, - on_row_converter: Arc>, - on_columns: Vec, - projection: Vec, - ) -> Result { - let empty_batch = RecordBatch::new_empty(Arc::new(Schema::new( - stream - .schema() - .fields() - .iter() - .map(|f| f.as_ref().clone().with_nullable(true)) - .collect::>(), - ))); - let null_batch = take_batch_opt(empty_batch, [Option::::None])?; - let null_on_rows = Arc::new( - on_row_converter - .lock() - .convert_columns(null_batch.project(&on_columns)?.columns())?, - ); - let null_nb = NullBuffer::new_null(1); - - Ok(Self { - stream, - on_row_converter, - on_columns, - projected_batches: vec![null_batch.project(&projection)?], - batches: vec![null_batch], - projection, - on_rows: vec![null_on_rows], - on_row_null_buffers: vec![Some(null_nb)], - cur_idx: (0, 0), - num_null_batches: 1, - mem_size: 0, - finished: false, - }) - } - - fn next(&mut self) -> NextAction { - let mut next_action = NextAction::None; - let mut cur_idx = self.cur_idx; - - if cur_idx.1 + 1 < self.batches[cur_idx.0].num_rows() { - cur_idx.1 += 1; - } else { - cur_idx.0 += 1; - cur_idx.1 = 0; - next_action = NextAction::LoadNextBatch; - } - self.cur_idx = cur_idx; - next_action - } - - async fn next_batch(&mut self, stop_timer: &mut ScopedTimerGuard<'_>) -> Result { - stop_timer.stop(); - if let Some(batch) = self.stream.next().await.transpose()? 
{ - stop_timer.restart(); - let on_columns = batch.project(&self.on_columns)?.columns().to_vec(); - let on_row_null_buffer = on_columns - .iter() - .map(|c| c.nulls().cloned()) - .reduce(|lhs, rhs| NullBuffer::union(lhs.as_ref(), rhs.as_ref())) - .unwrap_or(None); - let on_rows = Arc::new(self.on_row_converter.lock().convert_columns(&on_columns)?); - - self.mem_size += batch.get_array_mem_size(); - self.mem_size += on_row_null_buffer - .as_ref() - .map(|nb| nb.buffer().len()) - .unwrap_or_default(); - self.mem_size += on_rows.size(); - - self.projected_batches - .push(batch.project(&self.projection)?); - self.batches.push(batch); - self.on_row_null_buffers.push(on_row_null_buffer); - self.on_rows.push(on_rows); - return Ok(true); - } else { - stop_timer.restart(); +#[macro_export] +macro_rules! compare_cursor { + ($curs:expr) => {{ + match ($curs.0.cur_idx, $curs.1.cur_idx) { + (lidx, _) if $curs.0.is_null_key(lidx) => Ordering::Less, + (_, ridx) if $curs.1.is_null_key(ridx) => Ordering::Greater, + (lidx, ridx) => $curs.0.key(lidx).cmp(&$curs.1.key(ridx)), } - self.finished = true; - Ok(false) - } - - #[inline] - fn row<'a>(&'a self, idx: (usize, usize)) -> Row<'a> { - let bidx = idx.0; - let ridx = idx.1; - self.on_rows[bidx].row(ridx) - } - - #[inline] - fn num_buffered_batches(&self) -> usize { - self.batches.len() - self.num_null_batches - } - - #[inline] - fn mem_size(&self) -> usize { - self.mem_size - } - - #[inline] - fn clear_outdated(&mut self, min_reserved_bidx: usize) { - // fill out-dated batches with null batches - for i in self.num_null_batches..min_reserved_bidx.min(self.cur_idx.0) { - self.mem_size -= self.batches[i].get_array_mem_size(); - self.mem_size -= self.on_row_null_buffers[i] - .as_ref() - .map(|nb| nb.buffer().len()) - .unwrap_or_default(); - self.mem_size -= self.on_rows[i].size(); - - self.projected_batches[i] = self.projected_batches[0].clone(); - self.batches[i] = self.batches[0].clone(); - self.on_rows[i] = self.on_rows[0].clone(); - self.on_row_null_buffers[i] = self.on_row_null_buffers[0].clone(); - self.num_null_batches += 1; - } - } -} - -#[derive(Default)] -struct Joiner { - ljoins: Vec<(usize, usize)>, - rjoins: Vec<(usize, usize)>, - l_min_reserved_bidx: usize, - r_min_reserved_bidx: usize, -} - -impl Joiner { - fn new() -> Self { - Self { - ljoins: vec![], - rjoins: vec![], - l_min_reserved_bidx: usize::MAX, - r_min_reserved_bidx: usize::MAX, - } - } - - fn accept_pair( - &mut self, - join_params: &JoinParams, - lcur: &mut StreamCursor, - rcur: &mut StreamCursor, - l: Option<(usize, usize)>, - r: Option<(usize, usize)>, - ) -> Result> { - if let Some((bidx, ridx)) = l { - self.ljoins.push((bidx, ridx)); - self.l_min_reserved_bidx = self.l_min_reserved_bidx.min(bidx); - } else { - self.ljoins.push((0, 0)); - } - - if let Some((bidx, ridx)) = r { - self.rjoins.push((bidx, ridx)); - self.r_min_reserved_bidx = self.r_min_reserved_bidx.min(bidx); - } else { - self.rjoins.push((0, 0)); - } - - let batch_size = join_params.batch_size; - if self.ljoins.len() >= batch_size || self.rjoins.len() >= batch_size { - return self.flush_pairs(join_params, lcur, rcur); - } - Ok(None) - } - - fn is_empty(&self) -> bool { - self.ljoins.is_empty() && self.rjoins.is_empty() - } - - fn flush_pairs( - &mut self, - join_params: &JoinParams, - lcur: &mut StreamCursor, - rcur: &mut StreamCursor, - ) -> Result> { - self.l_min_reserved_bidx = usize::MAX; - self.r_min_reserved_bidx = usize::MAX; - - if let Some(join_filter) = &join_params.join_filter { - let num_intermediate_rows 
= std::cmp::max(self.ljoins.len(), self.rjoins.len()); - - // get intermediate batch - let intermediate_columns = join_filter - .column_indices() - .iter() - .map(|ci| { - let (cur, joins) = match ci.side { - JoinSide::Left => (&lcur, &self.ljoins), - JoinSide::Right => (&rcur, &self.rjoins), - }; - let arrays = cur - .batches - .iter() - .map(|b| b.column(ci.index).as_ref()) - .collect::>(); - Ok(arrow::compute::interleave(&arrays, joins)?) - }) - .collect::>>()?; - - let intermediate_batch = RecordBatch::try_new_with_options( - Arc::new(join_filter.schema().clone()), - intermediate_columns, - &RecordBatchOptions::new().with_row_count(Some(num_intermediate_rows)), - )?; - - // evalute filter - let filtered_array = join_filter - .expression() - .evaluate(&intermediate_batch)? - .into_array(intermediate_batch.num_rows())?; - let filtered = as_boolean_array(&filtered_array); - let filtered = if filtered.null_count() > 0 { - prep_null_mask_filter(filtered) - } else { - filtered.clone() - }; - - // apply filter - let mut retained = 0; - for (i, selected) in filtered.values().iter().enumerate() { - if selected { - self.ljoins[retained] = self.ljoins[i]; - self.rjoins[retained] = self.rjoins[i]; - retained += 1; - } - } - self.ljoins.truncate(retained); - self.rjoins.truncate(retained); - if retained == 0 { - return Ok(None); - } - } - - let lcols = || -> Result> { - Ok(if !lcur.projection.is_empty() { - interleave_batches( - lcur.projected_batches[0].schema(), - &lcur.projected_batches, - &self.ljoins, - )? - .columns() - .to_vec() - } else { - vec![] - }) - }; - let rcols = || -> Result> { - Ok(if !rcur.projection.is_empty() { - interleave_batches( - rcur.projected_batches[0].schema(), - &rcur.projected_batches, - &self.rjoins, - )? - .columns() - .to_vec() - } else { - vec![] - }) - }; - - let output_columns = match join_params.join_type { - LeftSemi | LeftAnti => lcols()?, - RightSemi | RightAnti => rcols()?, - _ => [lcols()?, rcols()?].concat(), - }; - let num_output_records = std::cmp::max(self.ljoins.len(), self.rjoins.len()); - self.ljoins.clear(); - self.rjoins.clear(); - let batch = RecordBatch::try_new_with_options( - join_params.output_schema.clone(), - output_columns, - &RecordBatchOptions::new().with_row_count(Some(num_output_records)), - )?; - Ok(Some(batch)) - } -} - -fn compare_cursor( - lcur: &StreamCursor, - lidx: (usize, usize), - rcur: &StreamCursor, - ridx: (usize, usize), -) -> Ordering { - match (&lcur.on_rows.get(lidx.0), &rcur.on_rows.get(ridx.0)) { - (None, _) => Ordering::Greater, - (_, None) => Ordering::Less, - (Some(lrows), Some(rrows)) => { - let lkey = &lrows.row(lidx.1); - let rkey = &rrows.row(ridx.1); - match lkey.cmp(rkey) { - Ordering::Greater => Ordering::Greater, - Ordering::Less => Ordering::Less, - _ => { - if let Some(nb) = &lcur.on_row_null_buffers[lidx.0] { - if nb.is_null(lidx.1) { - return Ordering::Less; - } - } - Ordering::Equal - } - } - } - } + }}; } -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use arrow::{ - self, - array::*, - compute::SortOptions, - datatypes::{DataType, Field, Schema}, - record_batch::RecordBatch, - }; - use datafusion::{ - assert_batches_sorted_eq, - error::Result, - logical_expr::{JoinType, JoinType::*}, - physical_expr::expressions::Column, - physical_plan::{common, joins::utils::*, memory::MemoryExec, ExecutionPlan}, - prelude::SessionContext, - }; - - use crate::sort_merge_join_exec::SortMergeJoinExec; - - fn columns(schema: &Schema) -> Vec { - schema.fields().iter().map(|f| f.name().clone()).collect() - } - - 
fn build_table_i32( - a: (&str, &Vec), - b: (&str, &Vec), - c: (&str, &Vec), - ) -> RecordBatch { - let schema = Schema::new(vec![ - Field::new(a.0, DataType::Int32, false), - Field::new(b.0, DataType::Int32, false), - Field::new(c.0, DataType::Int32, false), - ]); - - RecordBatch::try_new( - Arc::new(schema), - vec![ - Arc::new(Int32Array::from(a.1.clone())), - Arc::new(Int32Array::from(b.1.clone())), - Arc::new(Int32Array::from(c.1.clone())), - ], - ) - .unwrap() - } - - fn build_table( - a: (&str, &Vec), - b: (&str, &Vec), - c: (&str, &Vec), - ) -> Arc { - let batch = build_table_i32(a, b, c); - let schema = batch.schema(); - Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) - } - - fn build_table_from_batches(batches: Vec) -> Arc { - let schema = batches.first().unwrap().schema(); - Arc::new(MemoryExec::try_new(&[batches], schema, None).unwrap()) - } - - fn build_date_table( - a: (&str, &Vec), - b: (&str, &Vec), - c: (&str, &Vec), - ) -> Arc { - let schema = Schema::new(vec![ - Field::new(a.0, DataType::Date32, false), - Field::new(b.0, DataType::Date32, false), - Field::new(c.0, DataType::Date32, false), - ]); - - let batch = RecordBatch::try_new( - Arc::new(schema), - vec![ - Arc::new(Date32Array::from(a.1.clone())), - Arc::new(Date32Array::from(b.1.clone())), - Arc::new(Date32Array::from(c.1.clone())), - ], - ) - .unwrap(); - - let schema = batch.schema(); - Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) - } - - fn build_date64_table( - a: (&str, &Vec), - b: (&str, &Vec), - c: (&str, &Vec), - ) -> Arc { - let schema = Schema::new(vec![ - Field::new(a.0, DataType::Date64, false), - Field::new(b.0, DataType::Date64, false), - Field::new(c.0, DataType::Date64, false), - ]); - - let batch = RecordBatch::try_new( - Arc::new(schema), - vec![ - Arc::new(Date64Array::from(a.1.clone())), - Arc::new(Date64Array::from(b.1.clone())), - Arc::new(Date64Array::from(c.1.clone())), - ], - ) - .unwrap(); - - let schema = batch.schema(); - Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) - } - - /// returns a table with 3 columns of i32 in memory - pub fn build_table_i32_nullable( - a: (&str, &Vec>), - b: (&str, &Vec>), - c: (&str, &Vec>), - ) -> Arc { - let schema = Arc::new(Schema::new(vec![ - Field::new(a.0, DataType::Int32, true), - Field::new(b.0, DataType::Int32, true), - Field::new(c.0, DataType::Int32, true), - ])); - let batch = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(a.1.clone())), - Arc::new(Int32Array::from(b.1.clone())), - Arc::new(Int32Array::from(c.1.clone())), - ], - ) - .unwrap(); - Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) - } - - fn join_with_options( - left: Arc, - right: Arc, - on: JoinOn, - join_type: JoinType, - sort_options: Vec, - ) -> Result { - SortMergeJoinExec::try_new(left, right, on, join_type, None, sort_options) - } - - async fn join_collect( - left: Arc, - right: Arc, - on: JoinOn, - join_type: JoinType, - ) -> Result<(Vec, Vec)> { - let sort_options = vec![SortOptions::default(); on.len()]; - join_collect_with_options(left, right, on, join_type, sort_options).await - } - - async fn join_collect_with_options( - left: Arc, - right: Arc, - on: JoinOn, - join_type: JoinType, - sort_options: Vec, - ) -> Result<(Vec, Vec)> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); - let join = join_with_options(left, right, on, join_type, sort_options)?; - let columns = columns(&join.schema()); - - let stream = 
join.execute(0, task_ctx)?; - let batches = common::collect(stream).await?; - Ok((columns, batches)) - } - - #[tokio::test] - async fn join_inner_one() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![4, 5, 5]), // this has a repetition - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), - ("c2", &vec![70, 80, 90]), - ); - - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b1", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Inner).await?; - - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b1 | c2 |", - "+----+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 4 | 70 |", - "| 2 | 5 | 8 | 20 | 5 | 80 |", - "| 3 | 5 | 9 | 20 | 5 | 80 |", - "+----+----+----+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_inner_two() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 2]), - ("b2", &vec![1, 2, 2]), - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a1", &vec![1, 2, 3]), - ("b2", &vec![1, 2, 2]), - ("c2", &vec![70, 80, 90]), - ); - let on: JoinOn = vec![ - ( - Arc::new(Column::new_with_schema("a1", &left.schema())?), - Arc::new(Column::new_with_schema("a1", &right.schema())?), - ), - ( - Arc::new(Column::new_with_schema("b2", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - ), - ]; - - let (_columns, batches) = join_collect(left, right, on, Inner).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b2 | c1 | a1 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| 1 | 1 | 7 | 1 | 1 | 70 |", - "| 2 | 2 | 8 | 2 | 2 | 80 |", - "| 2 | 2 | 9 | 2 | 2 | 80 |", - "+----+----+----+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_inner_two_two() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 1, 2]), - ("b2", &vec![1, 1, 2]), - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a1", &vec![1, 1, 3]), - ("b2", &vec![1, 1, 2]), - ("c2", &vec![70, 80, 90]), - ); - let on: JoinOn = vec![ - ( - Arc::new(Column::new_with_schema("a1", &left.schema())?), - Arc::new(Column::new_with_schema("a1", &right.schema())?), - ), - ( - Arc::new(Column::new_with_schema("b2", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - ), - ]; - - let (_columns, batches) = join_collect(left, right, on, Inner).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b2 | c1 | a1 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| 1 | 1 | 7 | 1 | 1 | 70 |", - "| 1 | 1 | 7 | 1 | 1 | 80 |", - "| 1 | 1 | 8 | 1 | 1 | 70 |", - "| 1 | 1 | 8 | 1 | 1 | 80 |", - "+----+----+----+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_inner_with_nulls() -> Result<()> { - let left = build_table_i32_nullable( - ("a1", &vec![Some(1), Some(1), Some(2), Some(2)]), - ("b2", &vec![None, Some(1), Some(2), Some(2)]), // null in key field - ("c1", &vec![Some(1), None, Some(8), Some(9)]), // null in non-key field - ); - let right = build_table_i32_nullable( - ("a1", &vec![Some(1), 
Some(1), Some(2), Some(3)]), - ("b2", &vec![None, Some(1), Some(2), Some(2)]), - ("c2", &vec![Some(10), Some(70), Some(80), Some(90)]), - ); - let on: JoinOn = vec![ - ( - Arc::new(Column::new_with_schema("a1", &left.schema())?), - Arc::new(Column::new_with_schema("a1", &right.schema())?), - ), - ( - Arc::new(Column::new_with_schema("b2", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - ), - ]; - - let (_, batches) = join_collect(left, right, on, Inner).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b2 | c1 | a1 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| 1 | 1 | | 1 | 1 | 70 |", - "| 2 | 2 | 8 | 2 | 2 | 80 |", - "| 2 | 2 | 9 | 2 | 2 | 80 |", - "+----+----+----+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_inner_with_nulls_with_options() -> Result<()> { - let left = build_table_i32_nullable( - ("a1", &vec![Some(2), Some(2), Some(1), Some(1)]), - ("b2", &vec![Some(2), Some(2), Some(1), None]), // null in key field - ("c1", &vec![Some(9), Some(8), None, Some(1)]), // null in non-key field - ); - let right = build_table_i32_nullable( - ("a1", &vec![Some(3), Some(2), Some(1), Some(1)]), - ("b2", &vec![Some(2), Some(2), Some(1), None]), - ("c2", &vec![Some(90), Some(80), Some(70), Some(10)]), - ); - let on: JoinOn = vec![ - ( - Arc::new(Column::new_with_schema("a1", &left.schema())?), - Arc::new(Column::new_with_schema("a1", &right.schema())?), - ), - ( - Arc::new(Column::new_with_schema("b2", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - ), - ]; - let (_, batches) = join_collect_with_options( - left, - right, - on, - Inner, - vec![ - SortOptions { - descending: true, - nulls_first: false - }; - 2 - ], - // null_equals_null=false - ) - .await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b2 | c1 | a1 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| 2 | 2 | 9 | 2 | 2 | 80 |", - "| 2 | 2 | 8 | 2 | 2 | 80 |", - "| 1 | 1 | | 1 | 1 | 70 |", - "+----+----+----+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_left_one() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![4, 5, 7]), // 7 does not exist on the right - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), - ("c2", &vec![70, 80, 90]), - ); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b1", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Left).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b1 | c2 |", - "+----+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 4 | 70 |", - "| 2 | 5 | 8 | 20 | 5 | 80 |", - "| 3 | 7 | 9 | | | |", - "+----+----+----+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_right_one() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![4, 5, 7]), - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), // 6 does not exist on the left - ("c2", &vec![70, 80, 
90]), - ); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b1", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Right).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b1 | c2 |", - "+----+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 4 | 70 |", - "| 2 | 5 | 8 | 20 | 5 | 80 |", - "| | | | 30 | 6 | 90 |", - "+----+----+----+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_full_one() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![4, 5, 7]), // 7 does not exist on the right - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b2", &vec![4, 5, 6]), - ("c2", &vec![70, 80, 90]), - ); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Full).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| | | | 30 | 6 | 90 |", - "| 1 | 4 | 7 | 10 | 4 | 70 |", - "| 2 | 5 | 8 | 20 | 5 | 80 |", - "| 3 | 7 | 9 | | | |", - "+----+----+----+----+----+----+", - ]; - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_anti() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 2, 3, 5]), - ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right - ("c1", &vec![7, 8, 8, 9, 11]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), - ("c2", &vec![70, 80, 90]), - ); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b1", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, LeftAnti).await?; - let expected = vec![ - "+----+----+----+", - "| a1 | b1 | c1 |", - "+----+----+----+", - "| 3 | 7 | 9 |", - "| 5 | 7 | 11 |", - "+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_semi() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 2, 3]), - ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right - ("c1", &vec![7, 8, 8, 9]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), // 5 is double on the right - ("c2", &vec![70, 80, 90]), - ); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b1", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, LeftSemi).await?; - let expected = vec![ - "+----+----+----+", - "| a1 | b1 | c1 |", - "+----+----+----+", - "| 1 | 4 | 7 |", - "| 2 | 5 | 8 |", - "| 2 | 5 | 8 |", - "+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_with_duplicated_column_names() -> Result<()> { - let left = build_table( - ("a", &vec![1, 2, 3]), - ("b", &vec![4, 5, 7]), - ("c", &vec![7, 8, 9]), - ); - let right = build_table( - ("a", &vec![10, 20, 30]), - ("b", &vec![1, 2, 7]), - ("c", &vec![70, 80, 90]), - ); - let on: JoinOn 
= vec![( - // join on a=b so there are duplicate column names on unjoined columns - Arc::new(Column::new_with_schema("a", &left.schema())?), - Arc::new(Column::new_with_schema("b", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Inner).await?; - let expected = vec![ - "+---+---+---+----+---+----+", - "| a | b | c | a | b | c |", - "+---+---+---+----+---+----+", - "| 1 | 4 | 7 | 10 | 1 | 70 |", - "| 2 | 5 | 8 | 20 | 2 | 80 |", - "+---+---+---+----+---+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_date32() -> Result<()> { - let left = build_date_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![19107, 19108, 19108]), // this has a repetition - ("c1", &vec![7, 8, 9]), - ); - let right = build_date_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![19107, 19108, 19109]), - ("c2", &vec![70, 80, 90]), - ); - - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b1", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Inner).await?; - - let expected = vec![ - "+------------+------------+------------+------------+------------+------------+", - "| a1 | b1 | c1 | a2 | b1 | c2 |", - "+------------+------------+------------+------------+------------+------------+", - "| 1970-01-02 | 2022-04-25 | 1970-01-08 | 1970-01-11 | 2022-04-25 | 1970-03-12 |", - "| 1970-01-03 | 2022-04-26 | 1970-01-09 | 1970-01-21 | 2022-04-26 | 1970-03-22 |", - "| 1970-01-04 | 2022-04-26 | 1970-01-10 | 1970-01-21 | 2022-04-26 | 1970-03-22 |", - "+------------+------------+------------+------------+------------+------------+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_date64() -> Result<()> { - let left = build_date64_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![1650703441000, 1650903441000, 1650903441000]), // this has a repetition - ("c1", &vec![7, 8, 9]), - ); - let right = build_date64_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![1650703441000, 1650503441000, 1650903441000]), - ("c2", &vec![70, 80, 90]), - ); - - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b1", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Inner).await?; - let expected = vec![ - "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", - "| a1 | b1 | c1 | a2 | b1 | c2 |", - "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", - "| 1970-01-01T00:00:00.001 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.007 | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.070 |", - "| 1970-01-01T00:00:00.002 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |", - "| 1970-01-01T00:00:00.003 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |", - "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", - ]; - - // The output order is important as SMJ 
preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_left_sort_order() -> Result<()> { - let left = build_table( - ("a1", &vec![0, 1, 2, 3, 4, 5]), - ("b1", &vec![3, 4, 5, 6, 6, 7]), - ("c1", &vec![4, 5, 6, 7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![0, 10, 20, 30, 40]), - ("b2", &vec![2, 4, 6, 6, 8]), - ("c2", &vec![50, 60, 70, 80, 90]), - ); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Left).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| 0 | 3 | 4 | | | |", - "| 1 | 4 | 5 | 10 | 4 | 60 |", - "| 2 | 5 | 6 | | | |", - "| 3 | 6 | 7 | 20 | 6 | 70 |", - "| 3 | 6 | 7 | 30 | 6 | 80 |", - "| 4 | 6 | 8 | 20 | 6 | 70 |", - "| 4 | 6 | 8 | 30 | 6 | 80 |", - "| 5 | 7 | 9 | | | |", - "+----+----+----+----+----+----+", - ]; - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_right_sort_order() -> Result<()> { - let left = build_table( - ("a1", &vec![0, 1, 2, 3]), - ("b1", &vec![3, 4, 5, 7]), - ("c1", &vec![6, 7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![0, 10, 20, 30]), - ("b2", &vec![2, 4, 5, 6]), - ("c2", &vec![60, 70, 80, 90]), - ); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Right).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| | | | 0 | 2 | 60 |", - "| 1 | 4 | 7 | 10 | 4 | 70 |", - "| 2 | 5 | 8 | 20 | 5 | 80 |", - "| | | | 30 | 6 | 90 |", - "+----+----+----+----+----+----+", - ]; - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_left_multiple_batches() -> Result<()> { - let left_batch_1 = build_table_i32( - ("a1", &vec![0, 1, 2]), - ("b1", &vec![3, 4, 5]), - ("c1", &vec![4, 5, 6]), - ); - let left_batch_2 = build_table_i32( - ("a1", &vec![3, 4, 5, 6]), - ("b1", &vec![6, 6, 7, 9]), - ("c1", &vec![7, 8, 9, 9]), - ); - let right_batch_1 = build_table_i32( - ("a2", &vec![0, 10, 20]), - ("b2", &vec![2, 4, 6]), - ("c2", &vec![50, 60, 70]), - ); - let right_batch_2 = build_table_i32( - ("a2", &vec![30, 40]), - ("b2", &vec![6, 8]), - ("c2", &vec![80, 90]), - ); - let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); - let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Left).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| 0 | 3 | 4 | | | |", - "| 1 | 4 | 5 | 10 | 4 | 60 |", - "| 2 | 5 | 6 | | | |", - "| 3 | 6 | 7 | 20 | 6 | 70 |", - "| 3 | 6 | 7 | 30 | 6 | 80 |", - "| 4 | 6 | 8 | 20 | 6 | 70 |", - "| 4 | 6 | 8 | 30 | 6 | 80 |", - "| 5 | 7 | 9 | | | |", - "| 6 | 9 | 9 | | | |", - "+----+----+----+----+----+----+", - ]; - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_right_multiple_batches() -> Result<()> { - let right_batch_1 = 
build_table_i32( - ("a2", &vec![0, 1, 2]), - ("b2", &vec![3, 4, 5]), - ("c2", &vec![4, 5, 6]), - ); - let right_batch_2 = build_table_i32( - ("a2", &vec![3, 4, 5, 6]), - ("b2", &vec![6, 6, 7, 9]), - ("c2", &vec![7, 8, 9, 9]), - ); - let left_batch_1 = build_table_i32( - ("a1", &vec![0, 10, 20]), - ("b1", &vec![2, 4, 6]), - ("c1", &vec![50, 60, 70]), - ); - let left_batch_2 = build_table_i32( - ("a1", &vec![30, 40]), - ("b1", &vec![6, 8]), - ("c1", &vec![80, 90]), - ); - let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); - let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Right).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| | | | 0 | 3 | 4 |", - "| 10 | 4 | 60 | 1 | 4 | 5 |", - "| | | | 2 | 5 | 6 |", - "| 20 | 6 | 70 | 3 | 6 | 7 |", - "| 30 | 6 | 80 | 3 | 6 | 7 |", - "| 20 | 6 | 70 | 4 | 6 | 8 |", - "| 30 | 6 | 80 | 4 | 6 | 8 |", - "| | | | 5 | 7 | 9 |", - "| | | | 6 | 9 | 9 |", - "+----+----+----+----+----+----+", - ]; - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_full_multiple_batches() -> Result<()> { - let left_batch_1 = build_table_i32( - ("a1", &vec![0, 1, 2]), - ("b1", &vec![3, 4, 5]), - ("c1", &vec![4, 5, 6]), - ); - let left_batch_2 = build_table_i32( - ("a1", &vec![3, 4, 5, 6]), - ("b1", &vec![6, 6, 7, 9]), - ("c1", &vec![7, 8, 9, 9]), - ); - let right_batch_1 = build_table_i32( - ("a2", &vec![0, 10, 20]), - ("b2", &vec![2, 4, 6]), - ("c2", &vec![50, 60, 70]), - ); - let right_batch_2 = build_table_i32( - ("a2", &vec![30, 40]), - ("b2", &vec![6, 8]), - ("c2", &vec![80, 90]), - ); - let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); - let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Full).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| | | | 0 | 2 | 50 |", - "| | | | 40 | 8 | 90 |", - "| 0 | 3 | 4 | | | |", - "| 1 | 4 | 5 | 10 | 4 | 60 |", - "| 2 | 5 | 6 | | | |", - "| 3 | 6 | 7 | 20 | 6 | 70 |", - "| 3 | 6 | 7 | 30 | 6 | 80 |", - "| 4 | 6 | 8 | 20 | 6 | 70 |", - "| 4 | 6 | 8 | 30 | 6 | 80 |", - "| 5 | 7 | 9 | | | |", - "| 6 | 9 | 9 | | | |", - "+----+----+----+----+----+----+", - ]; - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } +#[async_trait] +pub trait Joiner { + async fn join(self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()>; + fn total_send_output_time(&self) -> usize; + fn num_output_rows(&self) -> usize; } diff --git a/pom.xml b/pom.xml index 8598cfd26..f5d8fa23d 100644 --- a/pom.xml +++ b/pom.xml @@ -13,7 +13,7 @@ - 2.0.9.1-SNAPSHOT + 3.0.0-SNAPSHOT UTF-8 15.0.2 3.21.9 @@ -107,6 +107,13 @@ + + + com.google.code.findbugs + jsr305 + 2.0.2 + + scala-compile-first diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 42d866a3b..1e25db4d0 100755 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -16,5 +16,5 @@ # under the License. 
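On the Rust side, the removed in-file Joiner struct and test module give way to a Joiner trait: each join strategy now implements an async join driven by a pair of stream cursors, plus metric accessors. A rough Scala analogue of that shape, for illustration only; StreamCursors here is a placeholder type, not Blaze's actual API:

import scala.concurrent.Future

trait StreamCursors // stand-in for the paired left/right stream cursors

trait Joiner {
  def join(cursors: StreamCursors): Future[Unit] // async in the real Rust trait
  def totalSendOutputTime: Long
  def numOutputRows: Long
}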
[toolchain] -channel = "nightly-2023-08-01" -components = ["cargo", "rustfmt", "clippy"] +channel = "nightly-2024-06-27" +components = ["rust-src", "cargo", "rustfmt", "clippy"] diff --git a/spark-extension-shims-spark303/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala b/spark-extension-shims-spark303/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala index 823a2bcb9..8dd9ea84b 100644 --- a/spark-extension-shims-spark303/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala +++ b/spark-extension-shims-spark303/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala @@ -79,8 +79,6 @@ import org.apache.spark.sql.execution.blaze.plan.NativeAggBase.AggExecMode import org.apache.spark.sql.execution.blaze.plan.NativeAggExec import org.apache.spark.sql.execution.blaze.plan.NativeBroadcastJoinBase import org.apache.spark.sql.execution.blaze.plan.NativeBroadcastJoinExec -import org.apache.spark.sql.execution.blaze.plan.NativeBroadcastNestedLoopJoinBase -import org.apache.spark.sql.execution.blaze.plan.NativeBroadcastNestedLoopJoinExec import org.apache.spark.sql.execution.blaze.plan.NativeExpandBase import org.apache.spark.sql.execution.blaze.plan.NativeExpandExec import org.apache.spark.sql.execution.blaze.plan.NativeFilterBase @@ -114,6 +112,7 @@ import org.apache.spark.sql.hive.execution.InsertIntoHiveTable import org.apache.spark.sql.types.DataType import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.execution.blaze.plan.BroadcastSide import org.apache.spark.sql.execution.blaze.plan.NativeParquetSinkBase import org.apache.spark.sql.execution.blaze.plan.NativeParquetSinkExec import org.blaze.{protobuf => pb} @@ -153,7 +152,7 @@ class ShimsImpl extends Shims with Logging { leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, - condition: Option[Expression]): NativeBroadcastJoinBase = + buildSide: BroadcastSide): NativeBroadcastJoinBase = NativeBroadcastJoinExec( left, right, @@ -161,14 +160,7 @@ class ShimsImpl extends Shims with Logging { leftKeys, rightKeys, joinType, - condition) - - override def createNativeBroadcastNestedLoopJoinExec( - left: SparkPlan, - right: SparkPlan, - joinType: JoinType, - condition: Option[Expression]): NativeBroadcastNestedLoopJoinBase = - NativeBroadcastNestedLoopJoinExec(left, right, joinType, condition) + buildSide) override def createNativeSortMergeJoinExec( left: SparkPlan, diff --git a/spark-extension-shims-spark303/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinExec.scala b/spark-extension-shims-spark303/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinExec.scala index 75d1b7c22..3101587d3 100644 --- a/spark-extension-shims-spark303/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinExec.scala +++ b/spark-extension-shims-spark303/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinExec.scala @@ -19,8 +19,9 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.joins import org.apache.spark.sql.execution.joins.BuildLeft +import org.apache.spark.sql.execution.joins.BuildRight +import org.apache.spark.sql.execution.joins.BuildSide import org.apache.spark.sql.execution.joins.HashJoin case class NativeBroadcastJoinExec( @@ -30,7 +31,7 
@@ case class NativeBroadcastJoinExec( override val leftKeys: Seq[Expression], override val rightKeys: Seq[Expression], override val joinType: JoinType, - override val condition: Option[Expression]) + broadcastSide: BroadcastSide) extends NativeBroadcastJoinBase( left, right, @@ -38,10 +39,15 @@ case class NativeBroadcastJoinExec( leftKeys, rightKeys, joinType, - condition) + broadcastSide) with HashJoin { - override val buildSide: joins.BuildSide = BuildLeft + override val condition: Option[Expression] = None + + override val buildSide: BuildSide = broadcastSide match { + case BroadcastLeft => BuildLeft + case BroadcastRight => BuildRight + } override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = copy(left = newChildren(0), right = newChildren(1)) diff --git a/spark-extension-shims-spark303/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastNestedLoopJoinExec.scala b/spark-extension-shims-spark303/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastNestedLoopJoinExec.scala deleted file mode 100644 index 7b215cea0..000000000 --- a/spark-extension-shims-spark303/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastNestedLoopJoinExec.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright 2022 The Blaze Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
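NativeBroadcastJoinExec no longer carries a join condition; it now receives a BroadcastSide and derives Spark's BuildSide from it. A self-contained sketch of that mapping, using local stand-ins for the Blaze and Spark types of the same names:

sealed trait BroadcastSide
case object BroadcastLeft extends BroadcastSide
case object BroadcastRight extends BroadcastSide

sealed trait BuildSide
case object BuildLeft extends BuildSide
case object BuildRight extends BuildSide

// Mirrors the buildSide override in the hunk above.
def toBuildSide(side: BroadcastSide): BuildSide = side match {
  case BroadcastLeft  => BuildLeft
  case BroadcastRight => BuildRight
}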
- */ -package org.apache.spark.sql.execution.blaze.plan - -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.execution.SparkPlan - -case class NativeBroadcastNestedLoopJoinExec( - override val left: SparkPlan, - override val right: SparkPlan, - joinType: JoinType, - condition: Option[Expression]) - extends NativeBroadcastNestedLoopJoinBase(left, right, joinType, condition) { - - override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = - copy(left = newChildren(0), right = newChildren(1)) -} diff --git a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala index 1d867a565..a394cf934 100644 --- a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala +++ b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala @@ -105,7 +105,6 @@ import org.apache.spark.sql.execution.blaze.plan.NativeWindowExec import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.CoalescedMapperPartitionSpec import org.apache.spark.sql.execution.joins.blaze.plan.NativeBroadcastJoinExec -import org.apache.spark.sql.execution.joins.blaze.plan.NativeBroadcastNestedLoopJoinExec import org.apache.spark.sql.execution.joins.blaze.plan.NativeSortMergeJoinExec import org.apache.spark.sql.hive.execution.InsertIntoHiveTable import org.apache.spark.sql.types.DataType @@ -150,7 +149,7 @@ class ShimsImpl extends Shims with Logging { leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, - condition: Option[Expression]): NativeBroadcastJoinBase = + broadcastSide: BroadcastSide): NativeBroadcastJoinBase = NativeBroadcastJoinExec( left, right, @@ -158,14 +157,7 @@ class ShimsImpl extends Shims with Logging { leftKeys, rightKeys, joinType, - condition) - - override def createNativeBroadcastNestedLoopJoinExec( - left: SparkPlan, - right: SparkPlan, - joinType: JoinType, - condition: Option[Expression]): NativeBroadcastNestedLoopJoinBase = - NativeBroadcastNestedLoopJoinExec(left, right, joinType, condition) + broadcastSide) override def createNativeSortMergeJoinExec( left: SparkPlan, diff --git a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeBlockStoreShuffleReader.scala b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeBlockStoreShuffleReader.scala index 292f23321..fdd3a2453 100644 --- a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeBlockStoreShuffleReader.scala +++ b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeBlockStoreShuffleReader.scala @@ -20,7 +20,6 @@ import java.io.InputStream import org.apache.spark.MapOutputTracker import org.apache.spark.SparkEnv import org.apache.spark.TaskContext - import org.apache.spark.internal.Logging import org.apache.spark.internal.config import org.apache.spark.io.CompressionCodec @@ -28,30 +27,21 @@ import org.apache.spark.shuffle.BaseShuffleHandle import org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.storage.BlockId import org.apache.spark.storage.BlockManager +import org.apache.spark.storage.BlockManagerId import org.apache.spark.storage.ShuffleBlockFetcherIterator class BlazeBlockStoreShuffleReader[K, C]( handle: 
BaseShuffleHandle[K, _, C], - startPartition: Int, - endPartition: Int, + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], context: TaskContext, readMetrics: ShuffleReadMetricsReporter, blockManager: BlockManager = SparkEnv.get.blockManager, mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, - startMapId: Option[Int] = None, - endMapId: Option[Int] = None, shouldBatchFetch: Boolean = false) extends BlazeBlockStoreShuffleReaderBase[K, C](handle, context) with Logging { override def readBlocks(): Iterator[(BlockId, InputStream)] = { - val blocksByAddress = mapOutputTracker.getMapSizesByExecutorId( - handle.shuffleId, - startMapId.getOrElse(0), - endMapId.getOrElse(Int.MaxValue), - startPartition, - endPartition) - new ShuffleBlockFetcherIterator( context, blockManager.blockStoreClient, diff --git a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleManager.scala b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleManager.scala index a7390ee01..83decb32a 100644 --- a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleManager.scala +++ b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleManager.scala @@ -22,6 +22,7 @@ import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.shuffle._ import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch import org.apache.spark.sql.execution.blaze.shuffle.BlazeShuffleDependency.isArrowShuffle class BlazeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { @@ -54,16 +55,27 @@ class BlazeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { if (isArrowShuffle(handle)) { + val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, _, C]] + val (blocksByAddress, canEnableBatchFetch) = + if (baseShuffleHandle.dependency.isShuffleMergeFinalizedMarked) { + val res = SparkEnv.get.mapOutputTracker.getPushBasedShuffleMapSizesByExecutorId( + handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition) + (res.iter, res.enableBatchFetch) + } else { + val address = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( + handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition) + (address, true) + } + new BlazeBlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], - startPartition, - endPartition, + blocksByAddress, context, metrics, SparkEnv.get.blockManager, SparkEnv.get.mapOutputTracker, - startMapId = Some(startMapIndex), - endMapId = Some(endMapIndex)) + shouldBatchFetch = + canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context)) } else { sortShuffleManager.getReader( handle, diff --git a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala index 3fc664959..de3f5f887 100644 --- a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala +++ b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala @@ -21,12 
+21,16 @@ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.optimizer.BuildLeft +import org.apache.spark.sql.catalyst.optimizer.BuildRight import org.apache.spark.sql.catalyst.optimizer.BuildSide import org.apache.spark.sql.catalyst.plans.physical.BroadcastDistribution import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.blaze.plan.BroadcastLeft +import org.apache.spark.sql.execution.blaze.plan.BroadcastRight +import org.apache.spark.sql.execution.blaze.plan.BroadcastSide import org.apache.spark.sql.execution.blaze.plan.NativeBroadcastJoinBase import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode import org.apache.spark.sql.execution.joins.HashedRelationInfo @@ -39,7 +43,7 @@ case class NativeBroadcastJoinExec( override val leftKeys: Seq[Expression], override val rightKeys: Seq[Expression], override val joinType: JoinType, - override val condition: Option[Expression]) + broadcastSide: BroadcastSide) extends NativeBroadcastJoinBase( left, right, @@ -47,9 +51,11 @@ case class NativeBroadcastJoinExec( leftKeys, rightKeys, joinType, - condition) + broadcastSide) with HashJoin { + override def condition: Option[Expression] = None + override def requiredChildDistribution: Seq[Distribution] = { val mode = HashedRelationBroadcastMode(buildBoundKeys, isNullAware = false) BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil @@ -65,7 +71,10 @@ case class NativeBroadcastJoinExec( throw new NotImplementedError("NativeBroadcastJoin dose not support codegen") } - override def buildSide: BuildSide = BuildLeft + override def buildSide: BuildSide = broadcastSide match { + case BroadcastLeft => BuildLeft + case BroadcastRight => BuildRight + } override protected def withNewChildrenInternal( newLeft: SparkPlan, diff --git a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastNestedLoopJoinExec.scala b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastNestedLoopJoinExec.scala deleted file mode 100644 index a129e91c0..000000000 --- a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastNestedLoopJoinExec.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright 2022 The Blaze Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.sql.execution.joins.blaze.plan - -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.blaze.plan.NativeBroadcastNestedLoopJoinBase - -case class NativeBroadcastNestedLoopJoinExec( - override val left: SparkPlan, - override val right: SparkPlan, - joinType: JoinType, - condition: Option[Expression]) - extends NativeBroadcastNestedLoopJoinBase(left, right, joinType, condition) { - - override protected def withNewChildrenInternal( - newLeft: SparkPlan, - newRight: SparkPlan): SparkPlan = - copy(left = newLeft, right = newRight) -} diff --git a/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java b/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java index 31c3b9ac5..f7b1f97a7 100644 --- a/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java +++ b/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java @@ -27,15 +27,13 @@ public enum BlazeConf { /// actual off-heap memory usage is expected to be spark.executor.memoryOverhead * fraction. MEMORY_FRACTION("spark.blaze.memoryFraction", 0.6), - /// translates inequality smj to native. improves performance in most cases, however some - /// issues are found in special cases, like tpcds q72. - SMJ_INEQUALITY_JOIN_ENABLE("spark.blaze.enable.smjInequalityJoin", false), - /// fallbacks to SortMergeJoin when executing BroadcastHashJoin with big broadcasted table. - BHJ_FALLBACKS_TO_SMJ_ENABLE("spark.blaze.enable.bhjFallbacksToSmj", true), + /// not available in blaze 3.0+ + BHJ_FALLBACKS_TO_SMJ_ENABLE("spark.blaze.enable.bhjFallbacksToSmj", false), /// fallbacks to SortMergeJoin when BroadcastHashJoin has a broadcasted table with rows more /// than this threshold. requires spark.blaze.enable.bhjFallbacksToSmj = true. + /// not available in blaze 3.0+ BHJ_FALLBACKS_TO_SMJ_ROWS_THRESHOLD("spark.blaze.bhjFallbacksToSmj.rows", 1000000), /// fallbacks to SortMergeJoin when BroadcastHashJoin has a broadcasted table with memory usage @@ -44,7 +42,7 @@ public enum BlazeConf { /// enable converting upper/lower functions to native, special cases may provide different /// outputs from spark due to different unicode versions. 
- CASE_CONVERT_FUNCTIONS_ENABLE("spark.blaze.enable.caseconvert.functions", false), + CASE_CONVERT_FUNCTIONS_ENABLE("spark.blaze.enable.caseconvert.functions", true), /// number of threads evaluating UDFs /// improves performance for special case that UDF concurrency matters @@ -64,6 +62,12 @@ public enum BlazeConf { /// mininum number of rows to trigger partial aggregate skipping PARTIAL_AGG_SKIPPING_MIN_ROWS("spark.blaze.partialAggSkipping.minRows", BATCH_SIZE.intConf() * 2), + + // parquet enable page filtering + PARQUET_ENABLE_PAGE_FILTERING("spark.blaze.parquet.enable.pageFiltering", false), + + // parqeut enable bloom filter + PARQUET_ENABLE_BLOOM_FILTER("spark.blaze.parquet.enable.bloomFilter", false), ; private String key; diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala index 09bf85e04..c24888d63 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala @@ -28,7 +28,6 @@ import org.apache.arrow.c.ArrowArray import org.apache.arrow.c.ArrowSchema import org.apache.arrow.c.CDataDictionaryProvider import org.apache.arrow.c.Data -import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.types.pojo.Schema import org.apache.spark.Partition diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConvertStrategy.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConvertStrategy.scala index 9f7eb610a..a7ab81de9 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConvertStrategy.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConvertStrategy.scala @@ -46,7 +46,8 @@ object BlazeConvertStrategy extends Logging { val convertibleTag: TreeNodeTag[Boolean] = TreeNodeTag("blaze.convertible") val convertStrategyTag: TreeNodeTag[ConvertStrategy] = TreeNodeTag("blaze.convert.strategy") - val childOrderingRequiredTag: TreeNodeTag[Boolean] = TreeNodeTag("blaze.child.ordering.required") + val childOrderingRequiredTag: TreeNodeTag[Boolean] = TreeNodeTag( + "blaze.child.ordering.required") def apply(exec: SparkPlan): Unit = { exec.foreach(_.setTagValue(convertibleTag, true)) diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala index 99d0172a4..7e76b4391 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala @@ -15,11 +15,8 @@ */ package org.apache.spark.sql.blaze -import java.util.UUID - import scala.annotation.tailrec import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat import org.apache.spark.SparkEnv @@ -57,7 +54,6 @@ import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec import org.apache.spark.sql.execution.blaze.plan._ import org.apache.spark.sql.execution.blaze.plan.NativeAggBase -import org.apache.spark.sql.execution.blaze.plan.NativeProjectBase import org.apache.spark.sql.execution.blaze.plan.NativeUnionBase import 
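The BlazeConf changes above flip spark.blaze.enable.caseconvert.functions to true by default and add two parquet reader flags that default to false. A usage sketch for enabling them; setting the keys through SparkSession.builder is an assumption for illustration (deployments may equally use spark-defaults.conf or spark-submit --conf):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("blaze-parquet-flags")
  // new in 3.0: parquet page filtering and bloom filter, both off by default
  .config("spark.blaze.parquet.enable.pageFiltering", "true")
  .config("spark.blaze.parquet.enable.bloomFilter", "true")
  // now defaults to true; disable if exact unicode parity with Spark's
  // upper/lower is required
  .config("spark.blaze.enable.caseconvert.functions", "false")
  .getOrCreate()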
org.apache.spark.sql.execution.blaze.plan.Util import org.apache.spark.sql.execution.command.DataWritingCommandExec @@ -128,7 +124,9 @@ object BlazeConverters extends Logging { var newExec = exec.withNewChildren(newChildren) exec.getTagValue(convertibleTag).foreach(newExec.setTagValue(convertibleTag, _)) exec.getTagValue(convertStrategyTag).foreach(newExec.setTagValue(convertStrategyTag, _)) - exec.getTagValue(childOrderingRequiredTag).foreach(newExec.setTagValue(childOrderingRequiredTag, _)) + exec + .getTagValue(childOrderingRequiredTag) + .foreach(newExec.setTagValue(childOrderingRequiredTag, _)) if (!isNeverConvert(newExec)) { newExec = convertSparkPlan(newExec) } @@ -333,45 +331,14 @@ object BlazeConverters extends Logging { val (leftKeys, rightKeys, joinType, condition, left, right) = (exec.leftKeys, exec.rightKeys, exec.joinType, exec.condition, exec.left, exec.right) logDebug(s"Converting SortMergeJoinExec: ${Shims.get.simpleStringWithNodeId(exec)}") - var nativeLeft = convertToNative(left) - var nativeRight = convertToNative(right) - var modifiedLeftKeys = leftKeys - var modifiedRightKeys = rightKeys - var needPostProject = false - - if (leftKeys.exists(!_.isInstanceOf[AttributeReference])) { - val (keys, exec) = buildJoinColumnsProject(nativeLeft, leftKeys) - modifiedLeftKeys = keys - nativeLeft = exec - needPostProject = true - } - if (rightKeys.exists(!_.isInstanceOf[AttributeReference])) { - val (keys, exec) = buildJoinColumnsProject(nativeRight, rightKeys) - modifiedRightKeys = keys - nativeRight = exec - needPostProject = true - } - val smjOrig = SortMergeJoinExec( - modifiedLeftKeys, - modifiedRightKeys, + Shims.get.createNativeSortMergeJoinExec( + addRenameColumnsExec(convertToNative(left)), + addRenameColumnsExec(convertToNative(right)), + leftKeys, + rightKeys, joinType, - condition, - addRenameColumnsExec(nativeLeft), - addRenameColumnsExec(nativeRight)) - val smj = Shims.get.createNativeSortMergeJoinExec( - smjOrig.left, - smjOrig.right, - smjOrig.leftKeys, - smjOrig.rightKeys, - smjOrig.joinType, - smjOrig.condition) - - if (needPostProject) { - buildPostJoinProject(smj, exec.output) - } else { - smj - } + condition) } def convertBroadcastHashJoinExec(exec: BroadcastHashJoinExec): SparkPlan = { @@ -385,84 +352,33 @@ object BlazeConverters extends Logging { exec.left, exec.right) logDebug(s"Converting BroadcastHashJoinExec: ${Shims.get.simpleStringWithNodeId(exec)}") - logDebug(s" leftKeys: ${exec.leftKeys}") - logDebug(s" rightKeys: ${exec.rightKeys}") - logDebug(s" joinType: ${exec.joinType}") - logDebug(s" buildSide: ${exec.buildSide}") - logDebug(s" condition: ${exec.condition}") - var (hashed, hashedKeys, nativeProbed, probedKeys) = buildSide match { + logDebug(s" leftKeys: $leftKeys") + logDebug(s" rightKeys: $rightKeys") + logDebug(s" joinType: $joinType") + logDebug(s" buildSide: $buildSide") + logDebug(s" condition: $condition") + assert(condition.isEmpty, "join condition is not supported") + + // verify build side is native + buildSide match { case BuildRight => assert(NativeHelper.isNative(right), "broadcast join build side is not native") - val convertedLeft = convertToNative(left) - (right, rightKeys, convertedLeft, leftKeys) - case BuildLeft => assert(NativeHelper.isNative(left), "broadcast join build side is not native") - val convertedRight = convertToNative(right) - (left, leftKeys, convertedRight, rightKeys) - - case _ => - // scalastyle:off throwerror - throw new NotImplementedError( - "Ignore BroadcastHashJoin with unsupported children structure") } 
- var modifiedHashedKeys = hashedKeys - var modifiedProbedKeys = probedKeys - var needPostProject = false - - if (hashedKeys.exists(!_.isInstanceOf[AttributeReference])) { - val (keys, exec) = buildJoinColumnsProject(hashed, hashedKeys) - modifiedHashedKeys = keys - hashed = exec - needPostProject = true - } - if (probedKeys.exists(!_.isInstanceOf[AttributeReference])) { - val (keys, exec) = buildJoinColumnsProject(nativeProbed, probedKeys) - modifiedProbedKeys = keys - nativeProbed = exec - needPostProject = true - } + Shims.get.createNativeBroadcastJoinExec( + addRenameColumnsExec(convertToNative(left)), + addRenameColumnsExec(convertToNative(right)), + exec.outputPartitioning, + leftKeys, + rightKeys, + joinType, + buildSide match { + case BuildLeft => BroadcastLeft + case BuildRight => BroadcastRight + }) - val modifiedJoinType = buildSide match { - case BuildLeft => joinType - case BuildRight => - needPostProject = true - val modifiedJoinType = joinType match { // reverse join type - case Inner => Inner - case FullOuter => FullOuter - case LeftOuter => RightOuter - case RightOuter => LeftOuter - case _ => - throw new NotImplementedError( - "BHJ Semi/Anti join with BuildRight is not yet supported") - } - modifiedJoinType - } - - val bhjOrig = BroadcastHashJoinExec( - modifiedHashedKeys, - modifiedProbedKeys, - modifiedJoinType, - BuildLeft, - condition, - addRenameColumnsExec(hashed), - addRenameColumnsExec(nativeProbed)) - - val bhj = Shims.get.createNativeBroadcastJoinExec( - bhjOrig.left, - bhjOrig.right, - bhjOrig.outputPartitioning, - bhjOrig.leftKeys, - bhjOrig.rightKeys, - bhjOrig.joinType, - bhjOrig.condition) - - if (needPostProject) { - buildPostJoinProject(bhj, exec.output) - } else { - bhj - } } catch { case e @ (_: NotImplementedError | _: Exception) => val underlyingBroadcast = exec.buildSide match { @@ -483,60 +399,29 @@ object BlazeConverters extends Logging { logDebug(s" joinType: ${exec.joinType}") logDebug(s" buildSide: ${exec.buildSide}") logDebug(s" condition: ${exec.condition}") - val (broadcasted, nativeProbed) = buildSide match { + assert(condition.isEmpty, "join condition is not supported") + + // verify build side is native + buildSide match { case BuildRight => assert(NativeHelper.isNative(right), "broadcast join build side is not native") - val convertedLeft = convertToNative(left) - (right, convertedLeft) - case BuildLeft => assert(NativeHelper.isNative(left), "broadcast join build side is not native") - val convertedRight = convertToNative(right) - (left, convertedRight) - - case _ => - // scalastyle:off throwerror - throw new NotImplementedError( - "Ignore BroadcastNestedLoopJoin with unsupported children structure") - } - - // the in-memory inner table is not the same in different join types - // reference: https://docs.rs/datafusion/latest/datafusion/physical_plan/joins/struct.NestedLoopJoinExec.html - var needPostProject = false - val (modifiedLeft, modifiedRight, modifiedJoinType) = (buildSide, joinType) match { - case (BuildLeft, RightOuter | FullOuter) => - (broadcasted, nativeProbed, joinType) // RightOuter, FullOuter => BuildLeft - case (BuildRight, Inner | LeftOuter | LeftSemi | LeftAnti) => - ( - nativeProbed, - broadcasted, - joinType - ) // Inner, LeftOuter, LeftSemi, LeftAnti => BuildRight - case _ => - needPostProject = true - val modifiedJoinType = joinType match { - case Inner => - (nativeProbed, broadcasted, Inner) // Inner + BuildLeft => BuildRight - case FullOuter => - (broadcasted, nativeProbed, FullOuter) // FullOuter + BuildRight => 
BuildLeft - case _ => - throw new NotImplementedError( - s"BNLJ $joinType with $buildSide is not yet supported") - } - modifiedJoinType } - val bnlj = Shims.get.createNativeBroadcastNestedLoopJoinExec( - addRenameColumnsExec(modifiedLeft), - addRenameColumnsExec(modifiedRight), - modifiedJoinType, - condition) + // reuse NativeBroadcastJoin with empty equility keys + Shims.get.createNativeBroadcastJoinExec( + addRenameColumnsExec(convertToNative(left)), + addRenameColumnsExec(convertToNative(right)), + exec.outputPartitioning, + Nil, + Nil, + joinType, + buildSide match { + case BuildLeft => BroadcastLeft + case BuildRight => BroadcastRight + }) - if (needPostProject) { - buildPostJoinProject(bnlj, exec.output) - } else { - bnlj - } } catch { case e @ (_: NotImplementedError | _: Exception) => val underlyingBroadcast = exec.buildSide match { @@ -851,44 +736,6 @@ object BlazeConverters extends Logging { exec } - private def buildJoinColumnsProject( - child: SparkPlan, - joinKeys: Seq[Expression]): (Seq[AttributeReference], NativeProjectBase) = { - val extraProjectList = ArrayBuffer[NamedExpression]() - val transformedKeys = ArrayBuffer[AttributeReference]() - - joinKeys.foreach { - case attr: AttributeReference => transformedKeys.append(attr) - case expr => - val aliasExpr = - Alias(expr, s"JOIN_KEY:${expr.toString()} (${UUID.randomUUID().toString})")() - extraProjectList.append(aliasExpr) - - val attr = AttributeReference( - aliasExpr.name, - aliasExpr.dataType, - aliasExpr.nullable, - aliasExpr.metadata)(aliasExpr.exprId, aliasExpr.qualifier) - transformedKeys.append(attr) - } - ( - transformedKeys, - Shims.get - .createNativeProjectExec(child.output ++ extraProjectList, addRenameColumnsExec(child))) - } - - private def buildPostJoinProject( - child: SparkPlan, - output: Seq[Attribute]): NativeProjectBase = { - val projectList = output - .filter(!_.name.startsWith("JOIN_KEY:")) - .map(attr => - AttributeReference(attr.name, attr.dataType, attr.nullable, attr.metadata)( - attr.exprId, - attr.qualifier)) - Shims.get.createNativeProjectExec(projectList, child) - } - private def getPartialAggProjection( aggregateExprs: Seq[AggregateExpression], groupingExprs: Seq[NamedExpression]) diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala index 1cbfcc8b0..355444a7d 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala @@ -52,12 +52,11 @@ import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.LeafExpression +import org.apache.spark.sql.catalyst.plans.ExistenceJoin import org.apache.spark.sql.execution.blaze.plan.Util import org.apache.spark.sql.execution.ScalarSubquery -import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.hive.blaze.HiveUDFUtil import org.apache.spark.sql.hive.blaze.HiveUDFUtil.getFunctionClassName -import org.apache.spark.sql.hive.blaze.HiveUDFUtil.isHiveSimpleUDF import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.ArrayType import org.apache.spark.sql.types.AtomicType @@ -1110,6 +1109,7 @@ object NativeConverters extends Logging { case FullOuter => pb.JoinType.FULL case LeftSemi => pb.JoinType.SEMI case LeftAnti => 
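BroadcastNestedLoopJoin conversion above now reuses createNativeBroadcastJoinExec with empty key lists instead of a dedicated native operator. A toy illustration of why that works: when every row maps to the same (empty) key, an equi-join pairs every left row with every right row, which is exactly nested-loop semantics. The helper below is purely illustrative, not Blaze code:

def equiJoin[L, R, K](left: Seq[L], right: Seq[R])(lk: L => K, rk: R => K): Seq[(L, R)] =
  for { l <- left; r <- right if lk(l) == rk(r) } yield (l, r)

// With Unit as the key type every key is equal, so all pairs are produced:
val pairs = equiJoin(Seq(1, 2), Seq("a", "b"))(_ => (), _ => ())
// pairs == Seq((1, "a"), (1, "b"), (2, "a"), (2, "b"))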
pb.JoinType.ANTI + case _: ExistenceJoin => pb.JoinType.EXISTENCE case _ => throw new NotImplementedError(s"unsupported join type: ${joinType}") } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala index a8aaad261..12deb83c6 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala @@ -79,13 +79,7 @@ abstract class Shims { leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, - condition: Option[Expression]): NativeBroadcastJoinBase - - def createNativeBroadcastNestedLoopJoinExec( - left: SparkPlan, - right: SparkPlan, - joinType: JoinType, - condition: Option[Expression]): NativeBroadcastNestedLoopJoinBase + broadcastSide: BroadcastSide): NativeBroadcastJoinBase def createNativeSortMergeJoinExec( left: SparkPlan, diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala index ff27e53ca..e9498137c 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala @@ -16,7 +16,6 @@ package org.apache.spark.sql.blaze import java.nio.ByteBuffer - import org.apache.arrow.c.ArrowArray import org.apache.arrow.c.Data import org.apache.arrow.vector.VectorSchemaRoot diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/util/Using.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/util/Using.scala index b78eb080d..b1039694d 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/util/Using.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/util/Using.scala @@ -19,15 +19,14 @@ import scala.util.control.{ControlThrowable, NonFatal} import scala.util.Try /** - * A utility for performing automatic resource management. It can be used to perform an - * operation using resources, after which it releases the resources in reverse order - * of their creation. + * A utility for performing automatic resource management. It can be used to perform an operation + * using resources, after which it releases the resources in reverse order of their creation. * * ==Usage== * - * There are multiple ways to automatically manage resources with `Using`. If you only need - * to manage a single resource, the [[Using.apply `apply`]] method is easiest; it wraps the - * resource opening, operation, and resource releasing in a `Try`. + * There are multiple ways to automatically manage resources with `Using`. If you only need to + * manage a single resource, the [[Using.apply `apply`]] method is easiest; it wraps the resource + * opening, operation, and resource releasing in a `Try`. * * Example: * {{{ @@ -37,9 +36,9 @@ import scala.util.Try * } * }}} * - * If you need to manage multiple resources, [[Using.Manager$.apply `Using.Manager`]] should - * be used. It allows the managing of arbitrarily many resources, whose creation, use, and - * release are all wrapped in a `Try`. + * If you need to manage multiple resources, [[Using.Manager$.apply `Using.Manager`]] should be + * used. It allows the managing of arbitrarily many resources, whose creation, use, and release + * are all wrapped in a `Try`. 
* * Example: * {{{ @@ -70,43 +69,44 @@ import scala.util.Try * * ==Suppression Behavior== * - * If two exceptions are thrown (e.g., by an operation and closing a resource), - * one of them is re-thrown, and the other is - * [[java.lang.Throwable.addSuppressed(Throwable) added to it as a suppressed exception]]. - * If the two exceptions are of different 'severities' (see below), the one of a higher - * severity is re-thrown, and the one of a lower severity is added to it as a suppressed - * exception. If the two exceptions are of the same severity, the one thrown first is - * re-thrown, and the one thrown second is added to it as a suppressed exception. - * If an exception is a [[scala.util.control.ControlThrowable `ControlThrowable`]], or - * if it does not support suppression (see - * [[java.lang.Throwable `Throwable`'s constructor with an `enableSuppression` parameter]]), - * an exception that would have been suppressed is instead discarded. + * If two exceptions are thrown (e.g., by an operation and closing a resource), one of them is + * re-thrown, and the other is + * [[java.lang.Throwable.addSuppressed(Throwable) added to it as a suppressed exception]]. If the + * two exceptions are of different 'severities' (see below), the one of a higher severity is + * re-thrown, and the one of a lower severity is added to it as a suppressed exception. If the two + * exceptions are of the same severity, the one thrown first is re-thrown, and the one thrown + * second is added to it as a suppressed exception. If an exception is a + * [[scala.util.control.ControlThrowable `ControlThrowable`]], or if it does not support + * suppression (see + * [[java.lang.Throwable `Throwable`'s constructor with an `enableSuppression` parameter]]), an + * exception that would have been suppressed is instead discarded. * * Exceptions are ranked from highest to lowest severity as follows: * - `java.lang.VirtualMachineError` * - `java.lang.LinkageError` * - `java.lang.InterruptedException` and `java.lang.ThreadDeath` - * - [[scala.util.control.NonFatal fatal exceptions]], excluding `scala.util.control.ControlThrowable` + * - [[scala.util.control.NonFatal fatal exceptions]], excluding + * `scala.util.control.ControlThrowable` * - `scala.util.control.ControlThrowable` * - all other exceptions * - * When more than two exceptions are thrown, the first two are combined and - * re-thrown as described above, and each successive exception thrown is combined - * as it is thrown. + * When more than two exceptions are thrown, the first two are combined and re-thrown as described + * above, and each successive exception thrown is combined as it is thrown. * - * @define suppressionBehavior See the main doc for [[Using `Using`]] for full details of - * suppression behavior. + * @define suppressionBehavior + * See the main doc for [[Using `Using`]] for full details of suppression behavior. */ object Using { /** - * Performs an operation using a resource, and then releases the resource, - * even if the operation throws an exception. + * Performs an operation using a resource, and then releases the resource, even if the operation + * throws an exception. 
* * $suppressionBehavior * - * @return a [[Try]] containing an exception if one or more were thrown, - * or the result of the operation if no exceptions were thrown + * @return + * a [[Try]] containing an exception if one or more were thrown, or the result of the + * operation if no exceptions were thrown */ def apply[R: Releasable, A](resource: => R)(f: R => A): Try[A] = Try { Using.resource(resource)(f) @@ -115,20 +115,20 @@ object Using { /** * A resource manager. * - * Resources can be registered with the manager by calling [[acquire `acquire`]]; - * such resources will be released in reverse order of their acquisition - * when the manager is closed, regardless of any exceptions thrown - * during use. + * Resources can be registered with the manager by calling [[acquire `acquire`]]; such resources + * will be released in reverse order of their acquisition when the manager is closed, regardless + * of any exceptions thrown during use. * * $suppressionBehavior * - * @note It is recommended for API designers to require an implicit `Manager` - * for the creation of custom resources, and to call `acquire` during those - * resources' construction. Doing so guarantees that the resource ''must'' be - * automatically managed, and makes it impossible to forget to do so. + * @note + * It is recommended for API designers to require an implicit `Manager` for the creation of + * custom resources, and to call `acquire` during those resources' construction. Doing so + * guarantees that the resource ''must'' be automatically managed, and makes it impossible to + * forget to do so. * - * Example: - * {{{ + * Example: + * {{{ * class SafeFileReader(file: File)(implicit manager: Using.Manager) * extends BufferedReader(new FileReader(file)) { * @@ -136,7 +136,7 @@ object Using { * * manager.acquire(this) * } - * }}} + * }}} */ final class Manager private { import Manager._ @@ -145,9 +145,8 @@ object Using { private[this] var resources: List[Resource[_]] = Nil /** - * Registers the specified resource with this manager, so that - * the resource is released when the manager is closed, and then - * returns the (unmodified) resource. + * Registers the specified resource with this manager, so that the resource is released when + * the manager is closed, and then returns the (unmodified) resource. */ def apply[R: Releasable](resource: R): R = { acquire(resource) @@ -155,8 +154,8 @@ object Using { } /** - * Registers the specified resource with this manager, so that - * the resource is released when the manager is closed. + * Registers the specified resource with this manager, so that the resource is released when + * the manager is closed. */ def acquire[R: Releasable](resource: R): Unit = { if (resource == null) throw new NullPointerException("null resource") @@ -194,8 +193,8 @@ object Using { object Manager { /** - * Performs an operation using a `Manager`, then closes the `Manager`, - * releasing its resources (in reverse order of acquisition). + * Performs an operation using a `Manager`, then closes the `Manager`, releasing its resources + * (in reverse order of acquisition). 
* * Example: * {{{ @@ -204,9 +203,8 @@ object Using { * } * }}} * - * If using resources which require an implicit `Manager` as a parameter, - * this method should be invoked with an `implicit` modifier before the function - * parameter: + * If using resources which require an implicit `Manager` as a parameter, this method should + * be invoked with an `implicit` modifier before the function parameter: * * Example: * {{{ @@ -217,10 +215,13 @@ object Using { * * See the main doc for [[Using `Using`]] for full details of suppression behavior. * - * @param op the operation to perform using the manager - * @tparam A the return type of the operation - * @return a [[Try]] containing an exception if one or more were thrown, - * or the result of the operation if no exceptions were thrown + * @param op + * the operation to perform using the manager + * @tparam A + * the return type of the operation + * @return + * a [[Try]] containing an exception if one or more were thrown, or the result of the + * operation if no exceptions were thrown */ def apply[A](op: Manager => A): Try[A] = Try { (new Manager).manage(op) } @@ -247,18 +248,21 @@ object Using { } /** - * Performs an operation using a resource, and then releases the resource, - * even if the operation throws an exception. This method behaves similarly - * to Java's try-with-resources. + * Performs an operation using a resource, and then releases the resource, even if the operation + * throws an exception. This method behaves similarly to Java's try-with-resources. * * $suppressionBehavior * - * @param resource the resource - * @param body the operation to perform with the resource - * @tparam R the type of the resource - * @tparam A the return type of the operation - * @return the result of the operation, if neither the operation nor - * releasing the resource throws + * @param resource + * the resource + * @param body + * the operation to perform with the resource + * @tparam R + * the type of the resource + * @tparam A + * the return type of the operation + * @return + * the result of the operation, if neither the operation nor releasing the resource throws */ def resource[R, A](resource: R)(body: R => A)(implicit releasable: Releasable[R]): A = { if (resource == null) throw new NullPointerException("null resource") @@ -281,20 +285,26 @@ object Using { } /** - * Performs an operation using two resources, and then releases the resources - * in reverse order, even if the operation throws an exception. This method - * behaves similarly to Java's try-with-resources. + * Performs an operation using two resources, and then releases the resources in reverse order, + * even if the operation throws an exception. This method behaves similarly to Java's + * try-with-resources. 
* * $suppressionBehavior * - * @param resource1 the first resource - * @param resource2 the second resource - * @param body the operation to perform using the resources - * @tparam R1 the type of the first resource - * @tparam R2 the type of the second resource - * @tparam A the return type of the operation - * @return the result of the operation, if neither the operation nor - * releasing the resources throws + * @param resource1 + * the first resource + * @param resource2 + * the second resource + * @param body + * the operation to perform using the resources + * @tparam R1 + * the type of the first resource + * @tparam R2 + * the type of the second resource + * @tparam A + * the return type of the operation + * @return + * the result of the operation, if neither the operation nor releasing the resources throws */ def resources[R1: Releasable, R2: Releasable, A](resource1: R1, resource2: => R2)( body: (R1, R2) => A): A = @@ -305,22 +315,30 @@ object Using { } /** - * Performs an operation using three resources, and then releases the resources - * in reverse order, even if the operation throws an exception. This method - * behaves similarly to Java's try-with-resources. + * Performs an operation using three resources, and then releases the resources in reverse + * order, even if the operation throws an exception. This method behaves similarly to Java's + * try-with-resources. * * $suppressionBehavior * - * @param resource1 the first resource - * @param resource2 the second resource - * @param resource3 the third resource - * @param body the operation to perform using the resources - * @tparam R1 the type of the first resource - * @tparam R2 the type of the second resource - * @tparam R3 the type of the third resource - * @tparam A the return type of the operation - * @return the result of the operation, if neither the operation nor - * releasing the resources throws + * @param resource1 + * the first resource + * @param resource2 + * the second resource + * @param resource3 + * the third resource + * @param body + * the operation to perform using the resources + * @tparam R1 + * the type of the first resource + * @tparam R2 + * the type of the second resource + * @tparam R3 + * the type of the third resource + * @tparam A + * the return type of the operation + * @return + * the result of the operation, if neither the operation nor releasing the resources throws */ def resources[R1: Releasable, R2: Releasable, R3: Releasable, A]( resource1: R1, @@ -335,24 +353,34 @@ object Using { } /** - * Performs an operation using four resources, and then releases the resources - * in reverse order, even if the operation throws an exception. This method - * behaves similarly to Java's try-with-resources. + * Performs an operation using four resources, and then releases the resources in reverse order, + * even if the operation throws an exception. This method behaves similarly to Java's + * try-with-resources. 
* * $suppressionBehavior * - * @param resource1 the first resource - * @param resource2 the second resource - * @param resource3 the third resource - * @param resource4 the fourth resource - * @param body the operation to perform using the resources - * @tparam R1 the type of the first resource - * @tparam R2 the type of the second resource - * @tparam R3 the type of the third resource - * @tparam R4 the type of the fourth resource - * @tparam A the return type of the operation - * @return the result of the operation, if neither the operation nor - * releasing the resources throws + * @param resource1 + * the first resource + * @param resource2 + * the second resource + * @param resource3 + * the third resource + * @param resource4 + * the fourth resource + * @param body + * the operation to perform using the resources + * @tparam R1 + * the type of the first resource + * @tparam R2 + * the type of the second resource + * @tparam R3 + * the type of the third resource + * @tparam R4 + * the type of the fourth resource + * @tparam A + * the return type of the operation + * @return + * the result of the operation, if neither the operation nor releasing the resources throws */ def resources[R1: Releasable, R2: Releasable, R3: Releasable, R4: Releasable, A]( resource1: R1, @@ -372,17 +400,18 @@ object Using { /** * A typeclass describing how to release a particular type of resource. * - * A resource is anything which needs to be released, closed, or otherwise cleaned up - * in some way after it is finished being used, and for which waiting for the object's - * garbage collection to be cleaned up would be unacceptable. For example, an instance of - * [[java.io.OutputStream]] would be considered a resource, because it is important to close - * the stream after it is finished being used. + * A resource is anything which needs to be released, closed, or otherwise cleaned up in some + * way after it is finished being used, and for which waiting for the object's garbage + * collection to be cleaned up would be unacceptable. For example, an instance of + * [[java.io.OutputStream]] would be considered a resource, because it is important to close the + * stream after it is finished being used. * - * An instance of `Releasable` is needed in order to automatically manage a resource - * with [[Using `Using`]]. An implicit instance is provided for all types extending + * An instance of `Releasable` is needed in order to automatically manage a resource with + * [[Using `Using`]]. An implicit instance is provided for all types extending * [[java.lang.AutoCloseable]]. 
* - * @tparam R the type of the resource + * @tparam R + * the type of the resource */ trait Releasable[-R] { diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/util/ArrowUtils.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/util/ArrowUtils.scala index 6e18f4712..f6ddfc607 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/util/ArrowUtils.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/util/ArrowUtils.scala @@ -15,7 +15,8 @@ */ package org.apache.spark.sql.execution.blaze.arrowio.util -import scala.collection.JavaConverters._ +import scala.collection.JavaConverters.asScalaBufferConverter +import scala.collection.JavaConverters.seqAsJavaListConverter import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.memory.RootAllocator @@ -31,7 +32,6 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.ShutdownHookManager object ArrowUtils { - val rootAllocator = new RootAllocator(Long.MaxValue) ShutdownHookManager.addShutdownHook(() => rootAllocator.close()) @@ -128,7 +128,7 @@ object ArrowUtils { ArrayType(elementType, containsNull = elementField.isNullable) case ArrowType.Struct.INSTANCE => - val fields = field.getChildren().asScala.map { child => + val fields = field.getChildren.asScala.map { child => val dt = fromArrowField(child) StructField(child.getName, dt, child.isNullable) } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeBase.scala index 052262325..852e5332d 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeBase.scala @@ -34,7 +34,6 @@ import org.apache.spark.sql.blaze.NativeHelper import org.apache.spark.sql.execution.blaze.arrowio.ArrowFFIExportIterator import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.OneToOneDependency -import org.apache.spark.sql.blaze.BlazeConf import org.blaze.protobuf.FFIReaderExecNode import org.blaze.protobuf.PhysicalPlanNode import org.blaze.protobuf.Schema diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeBase.scala index 5525fb473..6fcbd4786 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeBase.scala @@ -24,22 +24,19 @@ import java.util.concurrent.Future import java.util.concurrent.TimeoutException import java.util.concurrent.TimeUnit -import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConverters._ import scala.collection.immutable.SortedMap import scala.concurrent.Promise +import org.apache.commons.lang3.reflect.MethodUtils import org.apache.spark.OneToOneDependency import org.apache.spark.Partition import org.apache.spark.SparkException import org.apache.spark.TaskContext import org.apache.spark.broadcast -import org.blaze.{protobuf => pb} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession -import 
org.apache.spark.sql.blaze.BlazeCallNativeWrapper -import org.apache.spark.sql.blaze.BlazeConf import org.apache.spark.sql.blaze.JniBridge import org.apache.spark.sql.blaze.MetricNode import org.apache.spark.sql.blaze.NativeConverters @@ -49,7 +46,10 @@ import org.apache.spark.sql.blaze.NativeSupports import org.apache.spark.sql.blaze.Shims import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.BoundReference import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.InterpretedUnsafeProjection import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode import org.apache.spark.sql.catalyst.plans.physical.BroadcastPartitioning import org.apache.spark.sql.catalyst.plans.physical.IdentityBroadcastMode @@ -63,6 +63,8 @@ import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.types.BinaryType +import org.blaze.{protobuf => pb} abstract class NativeBroadcastExchangeBase(mode: BroadcastMode, override val child: SparkPlan) extends BroadcastExchangeLike @@ -71,10 +73,15 @@ abstract class NativeBroadcastExchangeBase(mode: BroadcastMode, override val chi override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = BroadcastPartitioning(mode) + def broadcastMode: BroadcastMode = this.mode + + protected val hashMapOutput: Seq[Attribute] = output + .map(_.withNullability(true)) :+ AttributeReference("~TABLE", BinaryType, nullable = true)() + protected val nativeSchema: pb.Schema = Util.getNativeSchema(output) + protected val nativeHashMapSchema: pb.Schema = Util.getNativeSchema(hashMapOutput) def getRunId: UUID - override lazy val metrics: Map[String, SQLMetric] = SortedMap[String, SQLMetric]() ++ Map( NativeHelper .getDefaultNativeMetrics(sparkContext) @@ -93,9 +100,6 @@ abstract class NativeBroadcastExchangeBase(mode: BroadcastMode, override val chi override def doPrepare(): Unit = { // Materialize the future. 
relationFuture - relationFuture - relationFuture - relationFuture } override def doExecuteBroadcast[T](): Broadcast[T] = { @@ -103,17 +107,31 @@ abstract class NativeBroadcastExchangeBase(mode: BroadcastMode, override val chi override def index: Int = 0 } val broadcastReadNativePlan = doExecuteNative().nativePlan(singlePartition, null) - val rows = NativeHelper.executeNativePlan( + val rowsIter = NativeHelper.executeNativePlan( broadcastReadNativePlan, MetricNode(Map(), Nil, None), singlePartition, None) - val v = mode.transform(rows.toArray) - + val pruneKeyField = new InterpretedUnsafeProjection( + output.zipWithIndex + .map(v => BoundReference(v._2, v._1.dataType, v._1.nullable)) + .toArray) + + val dataRows = rowsIter + .map(pruneKeyField) + .map(_.copy()) + .toArray + + val broadcast = relationFuture.get // broadcast must be resolved + val v = mode.transform(dataRows) val dummyBroadcasted = new Broadcast[Any](-1) { override protected def getValue(): Any = v - override protected def doUnpersist(blocking: Boolean): Unit = {} - override protected def doDestroy(blocking: Boolean): Unit = {} + override protected def doUnpersist(blocking: Boolean): Unit = { + MethodUtils.invokeMethod(broadcast, true, "doUnpersist", Array(blocking)) + } + override protected def doDestroy(blocking: Boolean): Unit = { + MethodUtils.invokeMethod(broadcast, true, "doDestroy", Array(blocking)) + } } dummyBroadcasted.asInstanceOf[Broadcast[T]] } @@ -154,13 +172,14 @@ abstract class NativeBroadcastExchangeBase(mode: BroadcastMode, override val chi Channels.newChannel(new ByteArrayInputStream(bytes)) }) } + JniBridge.resourcesMap.put(resourceId, () => provideIpcIterator()) pb.PhysicalPlanNode .newBuilder() .setIpcReader( pb.IpcReaderExecNode .newBuilder() - .setSchema(nativeSchema) + .setSchema(nativeHashMapSchema) .setNumPartitions(1) .setIpcProviderResourceId(resourceId) .build()) @@ -267,39 +286,21 @@ object NativeBroadcastExchangeBase { keys: Seq[Expression], nativeSchema: pb.Schema): Array[Array[Byte]] = { - if (!BlazeConf.BHJ_FALLBACKS_TO_SMJ_ENABLE.booleanConf() || keys.isEmpty) { - return collectedData // no need to sort data in driver side - } - val readerIpcProviderResourceId = s"BuildBroadcastDataReader:${UUID.randomUUID()}" val readerExec = pb.IpcReaderExecNode .newBuilder() .setSchema(nativeSchema) .setIpcProviderResourceId(readerIpcProviderResourceId) - val sortExec = pb.SortExecNode + val buildHashMapExec = pb.BroadcastJoinBuildHashMapExecNode .newBuilder() .setInput(pb.PhysicalPlanNode.newBuilder().setIpcReader(readerExec)) - .addAllExpr( - keys - .map(key => { - pb.PhysicalExprNode - .newBuilder() - .setSort( - pb.PhysicalSortExprNode - .newBuilder() - .setExpr(NativeConverters.convertExpr(key)) - .setAsc(true) - .setNullsFirst(true) - .build()) - .build() - }) - .asJava) + .addAllKeys(keys.map(key => NativeConverters.convertExpr(key)).asJava) val writerIpcProviderResourceId = s"BuildBroadcastDataWriter:${UUID.randomUUID()}" val writerExec = pb.IpcWriterExecNode .newBuilder() - .setInput(pb.PhysicalPlanNode.newBuilder().setSort(sortExec)) + .setInput(pb.PhysicalPlanNode.newBuilder().setBroadcastJoinBuildHashMap(buildHashMapExec)) .setIpcConsumerResourceId(writerIpcProviderResourceId) // build native sorter diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala index ec13b8fae..dc27d2aad 100644 --- 
a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala @@ -20,21 +20,24 @@ import scala.collection.immutable.SortedMap import org.apache.spark.OneToOneDependency import org.apache.spark.Partition -import org.apache.spark.sql.blaze.BlazeConf import org.apache.spark.sql.blaze.MetricNode import org.apache.spark.sql.blaze.NativeConverters import org.apache.spark.sql.blaze.NativeHelper import org.apache.spark.sql.blaze.NativeRDD import org.apache.spark.sql.blaze.NativeSupports +import org.apache.spark.sql.blaze.Shims +import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.catalyst.plans.LeftAnti -import org.apache.spark.sql.catalyst.plans.LeftSemi import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.BinaryExecNode +import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec +import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode +import org.apache.spark.sql.types.LongType import org.blaze.{protobuf => pb} +import org.blaze.protobuf.JoinOn abstract class NativeBroadcastJoinBase( override val left: SparkPlan, @@ -43,82 +46,114 @@ abstract class NativeBroadcastJoinBase( leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, - condition: Option[Expression]) + broadcastSide: BroadcastSide) extends BinaryExecNode with NativeSupports { - assert( - (joinType != LeftSemi && joinType != LeftAnti) || condition.isEmpty, - "Semi/Anti join with filter is not supported yet") - - assert( - !BlazeConf.BHJ_FALLBACKS_TO_SMJ_ENABLE.booleanConf() || BlazeConf.SMJ_INEQUALITY_JOIN_ENABLE - .booleanConf() || condition.isEmpty, - "Join filter is not supported when BhjFallbacksToSmj and SmjInequalityJoin both enabled") - override lazy val metrics: Map[String, SQLMetric] = SortedMap[String, SQLMetric]() ++ Map( NativeHelper .getDefaultNativeMetrics(sparkContext) .toSeq: _*) + private val isLongHashRelation = { + val baseBroadcast = broadcastSide match { + case BroadcastLeft => Shims.get.getUnderlyingBroadcast(left) + case BroadcastRight => Shims.get.getUnderlyingBroadcast(right) + } + val mode = baseBroadcast match { + case b: BroadcastExchangeExec => b.mode + case b: NativeBroadcastExchangeBase => b.broadcastMode + } + mode match { + case mode: HashedRelationBroadcastMode + if mode.key.length == 1 && mode.key.head.dataType == LongType => + true + case _ => false + } + } + + private def nativeSchema = Util.getNativeSchema(output) + private def nativeJoinOn = leftKeys.zip(rightKeys).map { case (leftKey, rightKey) => - val leftColumn = NativeConverters.convertExpr(leftKey).getColumn match { - case column if column.getName.isEmpty => - throw new NotImplementedError(s"BHJ leftKey is not column: ${leftKey}") - case column => column + val leftKeyExpr = leftKey match { + case k if !isLongHashRelation || k.dataType == LongType => k + case k => Cast(k, LongType) } - val rightColumn = NativeConverters.convertExpr(rightKey).getColumn match { - case column if column.getName.isEmpty => - throw new NotImplementedError(s"BHJ rightKey is not column: ${rightKey}") - case column => column + val rightKeyExpr = rightKey match { + case k if 
!isLongHashRelation || k.dataType == LongType => k + case k => Cast(k, LongType) } - pb.JoinOn + JoinOn .newBuilder() - .setLeft(leftColumn) - .setRight(rightColumn) + .setLeft(NativeConverters.convertExpr(leftKeyExpr)) + .setRight(NativeConverters.convertExpr(rightKeyExpr)) .build() } private def nativeJoinType = NativeConverters.convertJoinType(joinType) - private def nativeJoinFilter = - condition.map(NativeConverters.convertJoinFilter(_, left.output, right.output)) + private def nativeBroadcastSide = broadcastSide match { + case BroadcastLeft => pb.JoinSide.LEFT_SIDE + case BroadcastRight => pb.JoinSide.RIGHT_SIDE + } // check whether native converting is supported + nativeSchema nativeJoinType - nativeJoinFilter + nativeJoinOn + nativeBroadcastSide override def doExecuteNative(): NativeRDD = { val leftRDD = NativeHelper.executeNative(left) val rightRDD = NativeHelper.executeNative(right) val nativeMetrics = MetricNode(metrics, leftRDD.metrics :: rightRDD.metrics :: Nil) + val nativeSchema = this.nativeSchema val nativeJoinType = this.nativeJoinType val nativeJoinOn = this.nativeJoinOn - val nativeJoinFilter = this.nativeJoinFilter - val partitions = rightRDD.partitions + + val (probedRDD, builtRDD) = broadcastSide match { + case BroadcastLeft => (rightRDD, leftRDD) + case BroadcastRight => (leftRDD, rightRDD) + } new NativeRDD( sparkContext, nativeMetrics, - partitions, - rddDependencies = new OneToOneDependency(rightRDD) :: Nil, - rightRDD.isShuffleReadFull, + probedRDD.partitions, + rddDependencies = new OneToOneDependency(probedRDD) :: Nil, + probedRDD.isShuffleReadFull, (partition, context) => { val partition0 = new Partition() { override def index: Int = 0 } - val leftChild = leftRDD.nativePlan(partition0, context) - val rightChild = rightRDD.nativePlan(rightRDD.partitions(partition.index), context) + val (leftChild, rightChild) = broadcastSide match { + case BroadcastLeft => + ( + leftRDD.nativePlan(partition0, context), + rightRDD.nativePlan(rightRDD.partitions(partition.index), context)) + case BroadcastRight => + ( + leftRDD.nativePlan(leftRDD.partitions(partition.index), context), + rightRDD.nativePlan(partition0, context)) + } + val cachedBuildHashMapId = s"bhm_stage${context.stageId}_rdd${builtRDD.id}" + val broadcastJoinExec = pb.BroadcastJoinExecNode .newBuilder() + .setSchema(nativeSchema) .setLeft(leftChild) .setRight(rightChild) .setJoinType(nativeJoinType) + .setBroadcastSide(nativeBroadcastSide) + .setCachedBuildHashMapId(cachedBuildHashMapId) .addAllOn(nativeJoinOn.asJava) - nativeJoinFilter.foreach(joinFilter => broadcastJoinExec.setJoinFilter(joinFilter)) pb.PhysicalPlanNode.newBuilder().setBroadcastJoin(broadcastJoinExec).build() }, friendlyName = "NativeRDD.BroadcastJoin") } } + +class BroadcastSide {} +case object BroadcastLeft extends BroadcastSide {} +case object BroadcastRight extends BroadcastSide {} diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastNestedLoopJoinBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastNestedLoopJoinBase.scala deleted file mode 100644 index bfcf74752..000000000 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastNestedLoopJoinBase.scala +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Copyright 2022 The Blaze Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.blaze.plan - -import scala.collection.immutable.SortedMap - -import org.apache.spark.OneToOneDependency -import org.apache.spark.Partition -import org.apache.spark.sql.blaze.MetricNode -import org.apache.spark.sql.blaze.NativeConverters -import org.apache.spark.sql.blaze.NativeHelper -import org.apache.spark.sql.blaze.NativeRDD -import org.apache.spark.sql.blaze.NativeSupports -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.ExistenceJoin -import org.apache.spark.sql.catalyst.plans.FullOuter -import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.InnerLike -import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.catalyst.plans.LeftAnti -import org.apache.spark.sql.catalyst.plans.LeftExistence -import org.apache.spark.sql.catalyst.plans.LeftOuter -import org.apache.spark.sql.catalyst.plans.LeftSemi -import org.apache.spark.sql.catalyst.plans.RightOuter -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.execution.BinaryExecNode -import org.blaze.{protobuf => pb} - -abstract class NativeBroadcastNestedLoopJoinBase( - override val left: SparkPlan, - override val right: SparkPlan, - joinType: JoinType, - condition: Option[Expression]) - extends BinaryExecNode - with NativeSupports { - - override lazy val metrics: Map[String, SQLMetric] = SortedMap[String, SQLMetric]() ++ Map( - NativeHelper - .getDefaultNativeMetrics(sparkContext) - .filterKeys( - Set( - "stage_id", - "output_rows", - "elapsed_compute", - "input_batch_count", - "input_batch_mem_size", - "input_row_count")) - .toSeq: _*) - - override def output: Seq[Attribute] = { - joinType match { - case _: InnerLike => - left.output ++ right.output - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case j: ExistenceJoin => - left.output :+ j.exists - case LeftExistence(_) => - left.output - case x => - throw new IllegalArgumentException( - s"BroadcastNestedLoopJoin should not take $x as the JoinType") - } - } - - private def nativeJoinType = NativeConverters.convertJoinType(joinType) - private def nativeJoinFilter = - condition.map(NativeConverters.convertJoinFilter(_, left.output, right.output)) - - // check whether native converting is supported - nativeJoinType - nativeJoinFilter - - private val probedSide = joinType match { - case Inner | LeftOuter | LeftSemi | LeftAnti => "left" - case RightOuter | FullOuter => "right" - case other => s"NativeBroadcastNestedLoopJoin does not support join type $other" - } - - override def doExecuteNative(): NativeRDD = { - val leftRDD = NativeHelper.executeNative(left) - val rightRDD = NativeHelper.executeNative(right) - val nativeMetrics = MetricNode(metrics, 
leftRDD.metrics :: rightRDD.metrics :: Nil) - val nativeJoinType = this.nativeJoinType - val nativeJoinFilter = this.nativeJoinFilter - val partitions = probedSide match { - case "left" => leftRDD.partitions - case "right" => rightRDD.partitions - } - - new NativeRDD( - sparkContext, - nativeMetrics, - partitions, - rddDependencies = probedSide match { - case "left" => new OneToOneDependency(leftRDD) :: Nil - case "right" => new OneToOneDependency(rightRDD) :: Nil - }, - rightRDD.isShuffleReadFull, - (partition, context) => { - val partition0 = new Partition() { - override def index: Int = 0 - } - val (leftChild, rightChild) = probedSide match { - case "left" => - ( - leftRDD.nativePlan(leftRDD.partitions(partition.index), context), - rightRDD.nativePlan(partition0, context)) - case "right" => - ( - leftRDD.nativePlan(partition0, context), - rightRDD.nativePlan(rightRDD.partitions(partition.index), context)) - } - val bnlj = pb.BroadcastNestedLoopJoinExecNode - .newBuilder() - .setLeft(leftChild) - .setRight(rightChild) - .setJoinType(nativeJoinType) - - nativeJoinFilter.foreach(joinFilter => bnlj.setJoinFilter(joinFilter)) - pb.PhysicalPlanNode.newBuilder().setBroadcastNestedLoopJoin(bnlj).build() - }, - friendlyName = "NativeRDD.BroadcastNestedLoopJoin") - } -} diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGenerateBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGenerateBase.scala index 2349cc9e8..dc0e371a4 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGenerateBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGenerateBase.scala @@ -22,7 +22,6 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.OneToOneDependency import org.apache.spark.sql.blaze.MetricNode import org.apache.spark.sql.blaze.NativeConverters -import org.apache.spark.sql.blaze.NativeConverters.convertExprWithFallback import org.apache.spark.sql.blaze.NativeHelper import org.apache.spark.sql.blaze.NativeRDD import org.apache.spark.sql.blaze.NativeSupports diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetScanBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetScanBase.scala index 1bbfdc3a5..276fd532d 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetScanBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetScanBase.scala @@ -167,7 +167,7 @@ abstract class NativeParquetScanBase(basedFileScan: FileSourceScanExec) partitions.asInstanceOf[Array[Partition]], Nil, rddShuffleReadFull = true, - (partition, context) => { + (partition, _context) => { val resourceId = s"NativeParquetScanExec:${UUID.randomUUID().toString}" val sharedConf = broadcastedHadoopConf.value.value JniBridge.resourcesMap.put( diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetSinkBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetSinkBase.scala index 8e81d43f7..cd2e5a3ff 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetSinkBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetSinkBase.scala @@ -151,6 +151,6 @@ abstract class NativeParquetSinkBase( "ParquetSink") } - protected def 
newHadoopConf(tableDesc: TableDesc): Configuration = + protected def newHadoopConf(_tableDesc: TableDesc): Configuration = sparkSession.sessionState.newHadoopConf() } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortMergeJoinBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortMergeJoinBase.scala index 52efbcdc0..831211b06 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortMergeJoinBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortMergeJoinBase.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.plans.RightOuter import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.OneToOneDependency -import org.apache.spark.sql.blaze.BlazeConf import org.apache.spark.sql.blaze.MetricNode import org.apache.spark.sql.blaze.NativeConverters import org.apache.spark.sql.blaze.NativeHelper @@ -52,13 +51,7 @@ abstract class NativeSortMergeJoinBase( extends BinaryExecNode with NativeSupports { - assert( - (joinType != LeftSemi && joinType != LeftAnti) || condition.isEmpty, - "Semi/Anti join with filter is not supported yet") - - assert( - BlazeConf.SMJ_INEQUALITY_JOIN_ENABLE.booleanConf() || condition.isEmpty, - "inequality sort-merge join is not enabled") + assert(condition.isEmpty, "inequality join is not supported") override lazy val metrics: Map[String, SQLMetric] = SortedMap[String, SQLMetric]() ++ Map( NativeHelper @@ -81,21 +74,15 @@ abstract class NativeSortMergeJoinBase( keys.map(SortOrder(_, Ascending)) } + private def nativeSchema = Util.getNativeSchema(output) + private def nativeJoinOn = leftKeys.zip(rightKeys).map { case (leftKey, rightKey) => - val leftColumn = NativeConverters.convertExpr(leftKey).getColumn match { - case column if column.getName.isEmpty => - throw new NotImplementedError(s"SMJ leftKey is not column: ${leftKey}") - case column => column - } - val rightColumn = NativeConverters.convertExpr(rightKey).getColumn match { - case column if column.getName.isEmpty => - throw new NotImplementedError(s"SMJ rightKey is not column: ${rightKey}") - case column => column - } + val leftKeyExpr = NativeConverters.convertExpr(leftKey) + val rightKeyExpr = NativeConverters.convertExpr(rightKey) JoinOn .newBuilder() - .setLeft(leftColumn) - .setRight(rightColumn) + .setLeft(leftKeyExpr) + .setRight(rightKeyExpr) .build() } @@ -109,14 +96,11 @@ abstract class NativeSortMergeJoinBase( private def nativeJoinType = NativeConverters.convertJoinType(joinType) - private def nativeJoinFilter = - condition.map(NativeConverters.convertJoinFilter(_, left.output, right.output)) - // check whether native converting is supported + nativeSchema nativeSortOptions nativeJoinOn nativeJoinType - nativeJoinFilter override def doExecuteNative(): NativeRDD = { val leftRDD = NativeHelper.executeNative(left) @@ -125,7 +109,6 @@ abstract class NativeSortMergeJoinBase( val nativeSortOptions = this.nativeSortOptions val nativeJoinOn = this.nativeJoinOn val nativeJoinType = this.nativeJoinType - val nativeJoinFilter = this.nativeJoinFilter val partitions = if (joinType != RightOuter) { leftRDD.partitions @@ -161,13 +144,12 @@ abstract class NativeSortMergeJoinBase( val sortMergeJoinExec = SortMergeJoinExecNode .newBuilder() + .setSchema(nativeSchema) .setLeft(leftChild) .setRight(rightChild) .setJoinType(nativeJoinType) .addAllOn(nativeJoinOn.asJava) 
.addAllSortOptions(nativeSortOptions.asJava) - - nativeJoinFilter.foreach(joinFilter => sortMergeJoinExec.setJoinFilter(joinFilter)) PhysicalPlanNode.newBuilder().setSortMergeJoin(sortMergeJoinExec).build() }, friendlyName = "NativeRDD.SortMergeJoin")