From 50fd3d0f8d154890517c17785542d54f15c583d8 Mon Sep 17 00:00:00 2001 From: Zhang Li Date: Tue, 28 Nov 2023 14:08:29 +0800 Subject: [PATCH] minor fixes and reafcatoring (#333) * code refactoring supports partial aggregate skipping fix incorrect assertion when converting concat_ws function fix ffi "not all nodes and buffers were consumed" issues * supports nested type hashing supports nested type array() function * use arrow snapshot version --------- Co-authored-by: zhangli20 --- Cargo.lock | 47 +- Cargo.toml | 2 +- dev/mvn-build-helper/assembly/pom.xml | 3 + native-engine/blaze-jni-bridge/src/conf.rs | 75 +++ .../blaze-jni-bridge/src/jni_bridge.rs | 99 ++-- native-engine/blaze-jni-bridge/src/lib.rs | 7 +- native-engine/blaze-serde/proto/blaze.proto | 1 + native-engine/blaze-serde/src/error.rs | 8 +- native-engine/blaze-serde/src/from_proto.rs | 137 +++--- native-engine/blaze-serde/src/lib.rs | 20 +- native-engine/blaze/src/exec.rs | 43 +- native-engine/blaze/src/lib.rs | 6 +- native-engine/blaze/src/metrics.rs | 6 +- native-engine/blaze/src/rt.rs | 126 ++--- .../datafusion-ext-commons/Cargo.toml | 4 + .../src/array_builder.rs | 7 +- .../src}/bytes_arena.rs | 5 +- .../datafusion-ext-commons/src/cast.rs | 47 +- .../datafusion-ext-commons/src/ffi.rs | 53 --- .../datafusion-ext-commons/src/hadoop_fs.rs | 3 +- .../src/io/batch_serde.rs | 44 +- .../datafusion-ext-commons/src/io/mod.rs | 12 +- .../datafusion-ext-commons/src/lib.rs | 20 +- .../src}/rdxsort.rs | 3 +- .../src}/slim_bytes.rs | 3 +- .../datafusion-ext-commons/src/spark_hash.rs | 448 +++++++----------- .../src/streams/coalesce_stream.rs | 63 ++- .../src/streams/ffi_stream.rs | 33 +- .../src/streams/ipc_stream.rs | 36 +- .../datafusion-ext-commons/src/uda.rs | 16 +- .../datafusion-ext-exprs/src/cast.rs | 47 +- .../src/get_indexed_field.rs | 48 +- .../datafusion-ext-exprs/src/get_map_value.rs | 91 ++-- native-engine/datafusion-ext-exprs/src/lib.rs | 4 +- .../datafusion-ext-exprs/src/named_struct.rs | 58 ++- .../src/spark_scalar_subquery_wrapper.rs | 28 +- .../src/spark_udf_wrapper.rs | 37 +- .../src/string_contains.rs | 42 +- .../src/string_ends_with.rs | 56 ++- .../src/string_starts_with.rs | 39 +- .../datafusion-ext-functions/src/lib.rs | 7 +- .../src/spark_check_overflow.rs | 28 +- .../src/spark_get_json_object.rs | 38 +- .../src/spark_make_array.rs | 39 +- .../src/spark_make_decimal.rs | 23 +- .../src/spark_murmur3_hash.rs | 12 +- .../src/spark_null_if_zero.rs | 21 +- .../src/spark_strings.rs | 33 +- .../src/spark_unscaled_value.rs | 21 +- native-engine/datafusion-ext-plans/Cargo.toml | 5 +- .../datafusion-ext-plans/src/agg/agg_buf.rs | 53 ++- .../src/agg/agg_context.rs | 91 +++- .../src/agg/agg_tables.rs | 300 ++++++++---- .../datafusion-ext-plans/src/agg/avg.rs | 37 +- .../src/agg/collect_list.rs | 25 +- .../src/agg/collect_set.rs | 25 +- .../datafusion-ext-plans/src/agg/count.rs | 25 +- .../datafusion-ext-plans/src/agg/first.rs | 29 +- .../src/agg/first_ignores_null.rs | 33 +- .../datafusion-ext-plans/src/agg/maxmin.rs | 43 +- .../datafusion-ext-plans/src/agg/mod.rs | 19 +- .../datafusion-ext-plans/src/agg/sum.rs | 34 +- .../datafusion-ext-plans/src/agg_exec.rs | 408 +++++++--------- .../src/broadcast_join_exec.rs | 77 +-- .../src/broadcast_nested_loop_join_exec.rs | 36 +- .../src/common/batch_statisitcs.rs | 17 +- .../src/common/cached_exprs_evaluator.rs | 45 +- .../src/common/column_pruning.rs | 24 +- .../datafusion-ext-plans/src/common/mod.rs | 15 +- .../datafusion-ext-plans/src/common/output.rs | 178 +++---- 
.../datafusion-ext-plans/src/debug_exec.rs | 40 +- .../src/empty_partitions_exec.rs | 35 +- .../datafusion-ext-plans/src/expand_exec.rs | 78 +-- .../src/ffi_reader_exec.rs | 29 +- .../datafusion-ext-plans/src/filter_exec.rs | 79 +-- .../src/generate/explode.rs | 12 +- .../datafusion-ext-plans/src/generate/mod.rs | 17 +- .../datafusion-ext-plans/src/generate_exec.rs | 92 ++-- .../src/ipc_reader_exec.rs | 48 +- .../src/ipc_writer_exec.rs | 99 ++-- native-engine/datafusion-ext-plans/src/lib.rs | 2 +- .../datafusion-ext-plans/src/limit_exec.rs | 56 ++- .../memory_manager.rs => memmgr/mod.rs} | 9 +- .../src/{common => memmgr}/onheap_spill.rs | 18 +- .../datafusion-ext-plans/src/parquet_exec.rs | 77 ++- .../src/parquet_sink_exec.rs | 47 +- .../datafusion-ext-plans/src/project_exec.rs | 76 ++- .../src/rename_columns_exec.rs | 38 +- .../src/rss_shuffle_writer_exec.rs | 74 ++- .../src/shuffle/bucket_repartitioner.rs | 49 +- .../datafusion-ext-plans/src/shuffle/mod.rs | 39 +- .../datafusion-ext-plans/src/shuffle/rss.rs | 3 +- .../src/shuffle/rss_bucket_repartitioner.rs | 29 +- .../src/shuffle/rss_single_repartitioner.rs | 10 +- .../src/shuffle/rss_sort_repartitioner.rs | 35 +- .../src/shuffle/single_repartitioner.rs | 18 +- .../src/shuffle/sort_repartitioner.rs | 40 +- .../src/shuffle_writer_exec.rs | 73 ++- .../datafusion-ext-plans/src/sort_exec.rs | 166 +++---- .../src/sort_merge_join_exec.rs | 110 ++--- .../datafusion-ext-plans/src/window/mod.rs | 24 +- .../src/window/processors/agg_processor.rs | 20 +- .../src/window/processors/rank_processor.rs | 15 +- .../window/processors/row_number_processor.rs | 13 +- .../src/window/window_context.rs | 18 +- .../datafusion-ext-plans/src/window_exec.rs | 137 +++--- pom.xml | 13 +- rustfmt.toml | 16 +- .../org/apache/spark/sql/blaze/BlazeConf.java | 89 ++-- .../sql/blaze/BlazeCallNativeWrapper.scala | 112 +++-- .../spark/sql/blaze/NativeConverters.scala | 16 +- .../apache/spark/sql/blaze/NativeHelper.scala | 2 +- .../arrowio/ArrowFFIExportIterator.scala | 2 +- .../ArrowFFIStreamImportIterator.scala | 2 +- .../blaze/plan/ConvertToNativeBase.scala | 2 +- .../execution/blaze/plan/NativeAggBase.scala | 15 +- .../plan/NativeBroadcastExchangeBase.scala | 2 +- .../blaze/plan/NativeBroadcastJoinBase.scala | 4 +- .../blaze/plan/NativeSortMergeJoinBase.scala | 2 +- 119 files changed, 3011 insertions(+), 2559 deletions(-) create mode 100644 native-engine/blaze-jni-bridge/src/conf.rs rename native-engine/{datafusion-ext-plans/src/common => datafusion-ext-commons/src}/bytes_arena.rs (96%) delete mode 100644 native-engine/datafusion-ext-commons/src/ffi.rs rename native-engine/{datafusion-ext-plans/src/common => datafusion-ext-commons/src}/rdxsort.rs (98%) rename native-engine/{datafusion-ext-plans/src/common => datafusion-ext-commons/src}/slim_bytes.rs (99%) rename native-engine/datafusion-ext-plans/src/{common/memory_manager.rs => memmgr/mod.rs} (99%) rename native-engine/datafusion-ext-plans/src/{common => memmgr}/onheap_spill.rs (94%) diff --git a/Cargo.lock b/Cargo.lock index 86175b27..8375c2ac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -96,7 +96,7 @@ checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" [[package]] name = "arrow" version = "45.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=1173507b57#1173507b577dbce1bdf7621cb25604fd2385fd20" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=5a6d98d183#5a6d98d1832d9a0169e291db0a0ba0a88045e701" dependencies = [ "ahash", "arrow-arith", @@ -117,7 
+117,7 @@ dependencies = [ [[package]] name = "arrow-arith" version = "45.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=1173507b57#1173507b577dbce1bdf7621cb25604fd2385fd20" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=5a6d98d183#5a6d98d1832d9a0169e291db0a0ba0a88045e701" dependencies = [ "arrow-array", "arrow-buffer", @@ -131,7 +131,7 @@ dependencies = [ [[package]] name = "arrow-array" version = "45.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=1173507b57#1173507b577dbce1bdf7621cb25604fd2385fd20" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=5a6d98d183#5a6d98d1832d9a0169e291db0a0ba0a88045e701" dependencies = [ "ahash", "arrow-buffer", @@ -148,7 +148,7 @@ dependencies = [ [[package]] name = "arrow-buffer" version = "45.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=1173507b57#1173507b577dbce1bdf7621cb25604fd2385fd20" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=5a6d98d183#5a6d98d1832d9a0169e291db0a0ba0a88045e701" dependencies = [ "half", "num", @@ -157,7 +157,7 @@ dependencies = [ [[package]] name = "arrow-cast" version = "45.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=1173507b57#1173507b577dbce1bdf7621cb25604fd2385fd20" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=5a6d98d183#5a6d98d1832d9a0169e291db0a0ba0a88045e701" dependencies = [ "arrow-array", "arrow-buffer", @@ -174,7 +174,7 @@ dependencies = [ [[package]] name = "arrow-csv" version = "45.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=1173507b57#1173507b577dbce1bdf7621cb25604fd2385fd20" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=5a6d98d183#5a6d98d1832d9a0169e291db0a0ba0a88045e701" dependencies = [ "arrow-array", "arrow-buffer", @@ -192,7 +192,7 @@ dependencies = [ [[package]] name = "arrow-data" version = "45.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=1173507b57#1173507b577dbce1bdf7621cb25604fd2385fd20" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=5a6d98d183#5a6d98d1832d9a0169e291db0a0ba0a88045e701" dependencies = [ "arrow-buffer", "arrow-schema", @@ -203,7 +203,7 @@ dependencies = [ [[package]] name = "arrow-ipc" version = "45.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=1173507b57#1173507b577dbce1bdf7621cb25604fd2385fd20" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=5a6d98d183#5a6d98d1832d9a0169e291db0a0ba0a88045e701" dependencies = [ "arrow-array", "arrow-buffer", @@ -216,7 +216,7 @@ dependencies = [ [[package]] name = "arrow-json" version = "45.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=1173507b57#1173507b577dbce1bdf7621cb25604fd2385fd20" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=5a6d98d183#5a6d98d1832d9a0169e291db0a0ba0a88045e701" dependencies = [ "arrow-array", "arrow-buffer", @@ -235,7 +235,7 @@ dependencies = [ [[package]] name = "arrow-ord" version = "45.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=1173507b57#1173507b577dbce1bdf7621cb25604fd2385fd20" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=5a6d98d183#5a6d98d1832d9a0169e291db0a0ba0a88045e701" dependencies = [ "arrow-array", "arrow-buffer", @@ -249,7 +249,7 @@ dependencies = [ [[package]] name = "arrow-row" version = "45.0.0" -source = 
"git+https://github.com/blaze-init/arrow-rs.git?rev=1173507b57#1173507b577dbce1bdf7621cb25604fd2385fd20" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=5a6d98d183#5a6d98d1832d9a0169e291db0a0ba0a88045e701" dependencies = [ "ahash", "arrow-array", @@ -263,7 +263,7 @@ dependencies = [ [[package]] name = "arrow-schema" version = "45.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=1173507b57#1173507b577dbce1bdf7621cb25604fd2385fd20" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=5a6d98d183#5a6d98d1832d9a0169e291db0a0ba0a88045e701" dependencies = [ "bitflags 2.4.0", "serde", @@ -272,7 +272,7 @@ dependencies = [ [[package]] name = "arrow-select" version = "45.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=1173507b57#1173507b577dbce1bdf7621cb25604fd2385fd20" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=5a6d98d183#5a6d98d1832d9a0169e291db0a0ba0a88045e701" dependencies = [ "arrow-array", "arrow-buffer", @@ -284,7 +284,7 @@ dependencies = [ [[package]] name = "arrow-string" version = "45.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=1173507b57#1173507b577dbce1bdf7621cb25604fd2385fd20" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=5a6d98d183#5a6d98d1832d9a0169e291db0a0ba0a88045e701" dependencies = [ "arrow-array", "arrow-buffer", @@ -740,7 +740,7 @@ dependencies = [ [[package]] name = "datafusion" version = "30.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=312358e45#312358e453cc7bf8cf5bfbef15d6a1caefed5c42" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=38077432e#38077432e66bab447b7aa138b9be9608f252e9c2" dependencies = [ "ahash", "arrow", @@ -788,7 +788,7 @@ dependencies = [ [[package]] name = "datafusion-common" version = "30.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=312358e45#312358e453cc7bf8cf5bfbef15d6a1caefed5c42" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=38077432e#38077432e66bab447b7aa138b9be9608f252e9c2" dependencies = [ "arrow", "arrow-array", @@ -811,7 +811,7 @@ dependencies = [ [[package]] name = "datafusion-execution" version = "30.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=312358e45#312358e453cc7bf8cf5bfbef15d6a1caefed5c42" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=38077432e#38077432e66bab447b7aa138b9be9608f252e9c2" dependencies = [ "arrow", "dashmap", @@ -830,7 +830,7 @@ dependencies = [ [[package]] name = "datafusion-expr" version = "30.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=312358e45#312358e453cc7bf8cf5bfbef15d6a1caefed5c42" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=38077432e#38077432e66bab447b7aa138b9be9608f252e9c2" dependencies = [ "ahash", "arrow", @@ -861,6 +861,8 @@ dependencies = [ "once_cell", "paste", "postcard", + "rand", + "slimmer_box", "tempfile", "thrift", "tokio", @@ -929,7 +931,6 @@ dependencies = [ "panic-message", "parking_lot", "paste", - "rand", "slimmer_box", "tempfile", "tokio", @@ -939,7 +940,7 @@ dependencies = [ [[package]] name = "datafusion-optimizer" version = "30.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=312358e45#312358e453cc7bf8cf5bfbef15d6a1caefed5c42" +source = 
"git+https://github.com/blaze-init/arrow-datafusion.git?rev=38077432e#38077432e66bab447b7aa138b9be9608f252e9c2" dependencies = [ "arrow", "async-trait", @@ -956,7 +957,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" version = "30.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=312358e45#312358e453cc7bf8cf5bfbef15d6a1caefed5c42" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=38077432e#38077432e66bab447b7aa138b9be9608f252e9c2" dependencies = [ "ahash", "arrow", @@ -990,7 +991,7 @@ dependencies = [ [[package]] name = "datafusion-sql" version = "30.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=312358e45#312358e453cc7bf8cf5bfbef15d6a1caefed5c42" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=38077432e#38077432e66bab447b7aa138b9be9608f252e9c2" dependencies = [ "arrow", "arrow-schema", @@ -1814,7 +1815,7 @@ dependencies = [ [[package]] name = "parquet" version = "45.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=1173507b57#1173507b577dbce1bdf7621cb25604fd2385fd20" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=5a6d98d183#5a6d98d1832d9a0169e291db0a0ba0a88045e701" dependencies = [ "ahash", "arrow-array", diff --git a/Cargo.toml b/Cargo.toml index 8bacdd69..477f21b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,7 @@ strip = false [profile.pre] inherits = "release" -incremental = true +#incremental = true opt-level = 1 lto = false codegen-units = 16 diff --git a/dev/mvn-build-helper/assembly/pom.xml b/dev/mvn-build-helper/assembly/pom.xml index 7eca9ef0..6cd5822b 100644 --- a/dev/mvn-build-helper/assembly/pom.xml +++ b/dev/mvn-build-helper/assembly/pom.xml @@ -69,6 +69,9 @@ org/apache/commons/codec/**/* org/apache/commons/compress/**/* org/slf4j/**/* + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA diff --git a/native-engine/blaze-jni-bridge/src/conf.rs b/native-engine/blaze-jni-bridge/src/conf.rs new file mode 100644 index 00000000..dd476ed6 --- /dev/null +++ b/native-engine/blaze-jni-bridge/src/conf.rs @@ -0,0 +1,75 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use datafusion::common::Result; + +use crate::{jni_call_static, jni_new_string}; + +macro_rules! 
define_conf { + ($conftype:ty, $name:ident) => { + #[allow(non_camel_case_types)] + pub struct $name; + impl $conftype for $name { + fn key(&self) -> &'static str { + stringify!($name) + } + } + }; +} + +define_conf!(IntConf, BATCH_SIZE); +define_conf!(DoubleConf, MEMORY_FRACTION); +define_conf!(BooleanConf, SMJ_INEQUALITY_JOIN_ENABLE); +define_conf!(BooleanConf, BHJ_FALLBACKS_TO_SMJ_ENABLE); +define_conf!(IntConf, BHJ_FALLBACKS_TO_SMJ_ROWS_THRESHOLD); +define_conf!(IntConf, BHJ_FALLBACKS_TO_SMJ_MEM_THRESHOLD); +define_conf!(BooleanConf, CASE_CONVERT_FUNCTIONS_ENABLE); +define_conf!(IntConf, UDF_WRAPPER_NUM_THREADS); +define_conf!(BooleanConf, INPUT_BATCH_STATISTICS_ENABLE); +define_conf!(BooleanConf, IGNORE_CORRUPTED_FILES); +define_conf!(BooleanConf, PARTIAL_AGG_SKIPPING_ENABLE); +define_conf!(DoubleConf, PARTIAL_AGG_SKIPPING_RATIO); +define_conf!(IntConf, PARTIAL_AGG_SKIPPING_MIN_ROWS); + +pub trait BooleanConf { + fn key(&self) -> &'static str; + fn value(&self) -> Result { + let key = jni_new_string!(self.key())?; + jni_call_static!(BlazeConf.booleanConf(key.as_obj()) -> bool) + } +} + +pub trait IntConf { + fn key(&self) -> &'static str; + fn value(&self) -> Result { + let key = jni_new_string!(self.key())?; + jni_call_static!(BlazeConf.intConf(key.as_obj()) -> i32) + } +} + +pub trait LongConf { + fn key(&self) -> &'static str; + fn value(&self) -> Result { + let key = jni_new_string!(self.key())?; + jni_call_static!(BlazeConf.longConf(key.as_obj()) -> i64) + } +} + +pub trait DoubleConf { + fn key(&self) -> &'static str; + fn value(&self) -> Result { + let key = jni_new_string!(self.key())?; + jni_call_static!(BlazeConf.doubleConf(key.as_obj()) -> f64) + } +} diff --git a/native-engine/blaze-jni-bridge/src/jni_bridge.rs b/native-engine/blaze-jni-bridge/src/jni_bridge.rs index 10f0e685..d8c2f82b 100644 --- a/native-engine/blaze-jni-bridge/src/jni_bridge.rs +++ b/native-engine/blaze-jni-bridge/src/jni_bridge.rs @@ -13,21 +13,16 @@ // limitations under the License. pub use datafusion; -pub use jni; -pub use jni::errors::Result as JniResult; -pub use jni::objects::JClass; -pub use jni::objects::JMethodID; -pub use jni::objects::JObject; -pub use jni::objects::JStaticMethodID; -pub use jni::objects::JValue; -pub use jni::signature::Primitive; -pub use jni::signature::ReturnType; -pub use jni::sys::jvalue; -pub use jni::JNIEnv; -pub use jni::JavaVM; -pub use paste::paste; - +pub use jni::{ + self, + errors::Result as JniResult, + objects::{JClass, JMethodID, JObject, JStaticMethodID, JValue}, + signature::{Primitive, ReturnType}, + sys::jvalue, + JNIEnv, JavaVM, +}; use once_cell::sync::OnceCell; +pub use paste::paste; thread_local! 
{ pub static THREAD_JNIENV: once_cell::unsync::Lazy> = @@ -1123,22 +1118,14 @@ impl<'a> SparkMetricNode<'a> { #[allow(non_snake_case)] pub struct BlazeConf<'a> { pub class: JClass<'a>, - pub method_batchSize: JStaticMethodID, - pub method_batchSize_ret: ReturnType, - pub method_memoryFraction: JStaticMethodID, - pub method_memoryFraction_ret: ReturnType, - pub method_enableBhjFallbacksToSmj: JStaticMethodID, - pub method_enableBhjFallbacksToSmj_ret: ReturnType, - pub method_bhjFallbacksToSmjRowsThreshold: JStaticMethodID, - pub method_bhjFallbacksToSmjRowsThreshold_ret: ReturnType, - pub method_bhjFallbacksToSmjMemThreshold: JStaticMethodID, - pub method_bhjFallbacksToSmjMemThreshold_ret: ReturnType, - pub method_udfWrapperNumThreads: JStaticMethodID, - pub method_udfWrapperNumThreads_ret: ReturnType, - pub method_enableInputBatchStatistics: JStaticMethodID, - pub method_enableInputBatchStatistics_ret: ReturnType, - pub method_ignoreCorruptedFiles: JStaticMethodID, - pub method_ignoreCorruptedFiles_ret: ReturnType, + pub method_booleanConf: JStaticMethodID, + pub method_booleanConf_ret: ReturnType, + pub method_intConf: JStaticMethodID, + pub method_intConf_ret: ReturnType, + pub method_longConf: JStaticMethodID, + pub method_longConf_ret: ReturnType, + pub method_doubleConf: JStaticMethodID, + pub method_doubleConf_ret: ReturnType, } impl<'a> BlazeConf<'_> { @@ -1148,36 +1135,22 @@ impl<'a> BlazeConf<'_> { let class = get_global_jclass(env, Self::SIG_TYPE)?; Ok(BlazeConf { class, - method_batchSize: env.get_static_method_id(class, "batchSize", "()I").unwrap(), - method_batchSize_ret: ReturnType::Primitive(Primitive::Int), - method_memoryFraction: env - .get_static_method_id(class, "memoryFraction", "()D") - .unwrap(), - method_memoryFraction_ret: ReturnType::Primitive(Primitive::Double), - method_enableBhjFallbacksToSmj: env - .get_static_method_id(class, "enableBhjFallbacksToSmj", "()Z") + method_booleanConf: env + .get_static_method_id(class, "booleanConf", "(Ljava/lang/String;)Z") .unwrap(), - method_enableBhjFallbacksToSmj_ret: ReturnType::Primitive(Primitive::Boolean), - method_bhjFallbacksToSmjRowsThreshold: env - .get_static_method_id(class, "bhjFallbacksToSmjRowsThreshold", "()I") + method_booleanConf_ret: ReturnType::Primitive(Primitive::Boolean), + method_intConf: env + .get_static_method_id(class, "intConf", "(Ljava/lang/String;)I") .unwrap(), - method_bhjFallbacksToSmjRowsThreshold_ret: ReturnType::Primitive(Primitive::Int), - method_bhjFallbacksToSmjMemThreshold: env - .get_static_method_id(class, "bhjFallbacksToSmjMemThreshold", "()I") + method_intConf_ret: ReturnType::Primitive(Primitive::Int), + method_longConf: env + .get_static_method_id(class, "longConf", "(Ljava/lang/String;)J") .unwrap(), - method_bhjFallbacksToSmjMemThreshold_ret: ReturnType::Primitive(Primitive::Int), - method_udfWrapperNumThreads: env - .get_static_method_id(class, "udfWrapperNumThreads", "()I") + method_longConf_ret: ReturnType::Primitive(Primitive::Long), + method_doubleConf: env + .get_static_method_id(class, "doubleConf", "(Ljava/lang/String;)D") .unwrap(), - method_udfWrapperNumThreads_ret: ReturnType::Primitive(Primitive::Int), - method_enableInputBatchStatistics: env - .get_static_method_id(class, "enableInputBatchStatistics", "()Z") - .unwrap(), - method_enableInputBatchStatistics_ret: ReturnType::Primitive(Primitive::Boolean), - method_ignoreCorruptedFiles: env - .get_static_method_id(class, "ignoreCorruptedFiles", "()Z") - .unwrap(), - method_ignoreCorruptedFiles_ret: 
ReturnType::Primitive(Primitive::Boolean), + method_doubleConf_ret: ReturnType::Primitive(Primitive::Double), }) } } @@ -1241,8 +1214,10 @@ pub struct BlazeCallNativeWrapper<'a> { pub method_getRawTaskDefinition_ret: ReturnType, pub method_getMetrics: JMethodID, pub method_getMetrics_ret: ReturnType, - pub method_setArrowFFIStreamPtr: JMethodID, - pub method_setArrowFFIStreamPtr_ret: ReturnType, + pub method_importSchema: JMethodID, + pub method_importSchema_ret: ReturnType, + pub method_importBatch: JMethodID, + pub method_importBatch_ret: ReturnType, pub method_setError: JMethodID, pub method_setError_ret: ReturnType, } @@ -1257,10 +1232,6 @@ impl<'a> BlazeCallNativeWrapper<'a> { .get_method_id(class, "getRawTaskDefinition", "()[B") .unwrap(), method_getRawTaskDefinition_ret: ReturnType::Array, - method_setArrowFFIStreamPtr: env - .get_method_id(class, "setArrowFFIStreamPtr", "(J)V") - .unwrap(), - method_setArrowFFIStreamPtr_ret: ReturnType::Primitive(Primitive::Void), method_getMetrics: env .get_method_id( class, @@ -1269,6 +1240,10 @@ impl<'a> BlazeCallNativeWrapper<'a> { ) .unwrap(), method_getMetrics_ret: ReturnType::Object, + method_importSchema: env.get_method_id(class, "importSchema", "(J)V").unwrap(), + method_importSchema_ret: ReturnType::Primitive(Primitive::Void), + method_importBatch: env.get_method_id(class, "importBatch", "(J)V").unwrap(), + method_importBatch_ret: ReturnType::Primitive(Primitive::Void), method_setError: env .get_method_id(class, "setError", "(Ljava/lang/Throwable;)V") .unwrap(), diff --git a/native-engine/blaze-jni-bridge/src/lib.rs b/native-engine/blaze-jni-bridge/src/lib.rs index 01e2afd9..108d7e0c 100644 --- a/native-engine/blaze-jni-bridge/src/lib.rs +++ b/native-engine/blaze-jni-bridge/src/lib.rs @@ -13,10 +13,13 @@ // limitations under the License. 
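The new conf.rs module together with the reworked BlazeConf bridge above replaces the old per-setting JNI accessors (batchSize(), memoryFraction(), ...) with generic booleanConf/intConf/longConf/doubleConf lookups keyed by the constant's own name. A minimal usage sketch from the native side, assuming the JNI bridge has already been initialized; the BATCH_SIZE and MEMORY_FRACTION reads mirror the exec.rs change later in this patch, and the logging line is only illustrative:

    use blaze_jni_bridge::conf::{self, BooleanConf, DoubleConf, IntConf};
    use datafusion::common::Result;

    // Each constant defined via define_conf! uses its own identifier as the lookup
    // key, so reading BATCH_SIZE ends up calling BlazeConf.intConf("BATCH_SIZE")
    // on the JVM side through jni_call_static!.
    fn read_native_confs() -> Result<()> {
        let batch_size = conf::BATCH_SIZE.value()? as usize;
        let memory_fraction = conf::MEMORY_FRACTION.value()?;
        let partial_skipping = conf::PARTIAL_AGG_SKIPPING_ENABLE.value()?;
        log::info!(
            "batch_size={batch_size}, memory_fraction={memory_fraction}, \
             partial_agg_skipping={partial_skipping}"
        );
        Ok(())
    }
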
use datafusion::common::Result; -use jni::objects::GlobalRef; -use jni::sys::{JNI_FALSE, JNI_TRUE}; +use jni::{ + objects::GlobalRef, + sys::{JNI_FALSE, JNI_TRUE}, +}; use once_cell::sync::OnceCell; +pub mod conf; pub mod jni_bridge; pub fn is_jni_bridge_inited() -> bool { diff --git a/native-engine/blaze-serde/proto/blaze.proto b/native-engine/blaze-serde/proto/blaze.proto index 96430ff8..64b640a1 100644 --- a/native-engine/blaze-serde/proto/blaze.proto +++ b/native-engine/blaze-serde/proto/blaze.proto @@ -584,6 +584,7 @@ message AggExecNode { repeated string grouping_expr_name = 6; repeated string agg_expr_name = 7; uint64 initial_input_buffer_offset = 8; + bool supports_partial_skipping = 9; } enum AggExecMode { diff --git a/native-engine/blaze-serde/src/error.rs b/native-engine/blaze-serde/src/error.rs index 4b473086..0a4d66be 100644 --- a/native-engine/blaze-serde/src/error.rs +++ b/native-engine/blaze-serde/src/error.rs @@ -114,10 +114,12 @@ pub trait FromOptionalField { /// on the contained data, returning any error encountered fn optional(self) -> std::result::Result, PlanSerDeError>; - /// Converts an optional protobuf field to a different type, returning an error if None + /// Converts an optional protobuf field to a different type, returning an + /// error if None /// - /// Returns `Error::MissingRequiredField` if None, otherwise calls [`FromField::field`] - /// on the contained data, returning any error encountered + /// Returns `Error::MissingRequiredField` if None, otherwise calls + /// [`FromField::field`] on the contained data, returning any error + /// encountered fn required(self, field: impl Into) -> std::result::Result; } diff --git a/native-engine/blaze-serde/src/from_proto.rs b/native-engine/blaze-serde/src/from_proto.rs index 6bfe86ea..6675bb89 100644 --- a/native-engine/blaze-serde/src/from_proto.rs +++ b/native-engine/blaze-serde/src/from_proto.rs @@ -14,79 +14,83 @@ //! Serde code to convert from protocol buffers to Rust data structures. 
-use std::convert::{TryFrom, TryInto}; -use std::sync::Arc; +use std::{ + convert::{TryFrom, TryInto}, + sync::Arc, +}; use arrow::datatypes::{FieldRef, SchemaRef}; -use base64::prelude::BASE64_URL_SAFE_NO_PAD; -use base64::Engine; +use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine}; use chrono::DateTime; -use datafusion::datasource::listing::{FileRange, PartitionedFile}; -use datafusion::datasource::object_store::ObjectStoreUrl; -use datafusion::datasource::physical_plan::FileScanConfig; -use datafusion::error::DataFusionError; -use datafusion::execution::context::ExecutionProps; -use datafusion::logical_expr::{BuiltinScalarFunction, Operator}; -use datafusion::physical_expr::expressions::{LikeExpr, SCAndExpr, SCOrExpr}; -use datafusion::physical_expr::{functions, ScalarFunctionExpr}; -use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; -use datafusion::physical_plan::sorts::sort::SortOptions; -use datafusion::physical_plan::union::UnionExec; -use datafusion::physical_plan::{ - expressions as phys_expr, - expressions::{ - BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, Literal, - NegativeExpr, NotExpr, PhysicalSortExpr, +use datafusion::{ + datasource::{ + listing::{FileRange, PartitionedFile}, + object_store::ObjectStoreUrl, + physical_plan::FileScanConfig, + }, + error::DataFusionError, + execution::context::ExecutionProps, + logical_expr::{BuiltinScalarFunction, Operator}, + physical_expr::{ + expressions::{LikeExpr, SCAndExpr, SCOrExpr}, + functions, ScalarFunctionExpr, + }, + physical_plan::{ + expressions as phys_expr, + expressions::{ + BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, Literal, + NegativeExpr, NotExpr, PhysicalSortExpr, + }, + joins::utils::{ColumnIndex, JoinFilter}, + sorts::sort::SortOptions, + union::UnionExec, + ColumnStatistics, ExecutionPlan, Partitioning, PhysicalExpr, Statistics, }, - Partitioning, }; -use datafusion::physical_plan::{ColumnStatistics, ExecutionPlan, PhysicalExpr, Statistics}; - use datafusion_ext_commons::streams::ipc_stream::IpcReadMode; -use datafusion_ext_plans::agg::{ - create_agg, AggExecMode, AggExpr, AggFunction, AggMode, GroupingExpr, +use datafusion_ext_exprs::{ + cast::TryCastExpr, get_indexed_field::GetIndexedFieldExpr, get_map_value::GetMapValueExpr, + named_struct::NamedStructExpr, spark_scalar_subquery_wrapper::SparkScalarSubqueryWrapperExpr, + spark_udf_wrapper::SparkUDFWrapperExpr, string_contains::StringContainsExpr, + string_ends_with::StringEndsWithExpr, string_starts_with::StringStartsWithExpr, +}; +use datafusion_ext_plans::{ + agg::{create_agg, AggExecMode, AggExpr, AggFunction, AggMode, GroupingExpr}, + agg_exec::AggExec, + broadcast_join_exec::BroadcastJoinExec, + broadcast_nested_loop_join_exec::BroadcastNestedLoopJoinExec, + debug_exec::DebugExec, + empty_partitions_exec::EmptyPartitionsExec, + expand_exec::ExpandExec, + ffi_reader_exec::FFIReaderExec, + filter_exec::FilterExec, + generate::create_generator, + generate_exec::GenerateExec, + ipc_reader_exec::IpcReaderExec, + ipc_writer_exec::IpcWriterExec, + limit_exec::LimitExec, + parquet_exec::ParquetExec, + parquet_sink_exec::ParquetSinkExec, + project_exec::ProjectExec, + rename_columns_exec::RenameColumnsExec, + rss_shuffle_writer_exec::RssShuffleWriterExec, + shuffle_writer_exec::ShuffleWriterExec, + sort_exec::SortExec, + sort_merge_join_exec::SortMergeJoinExec, + window::{WindowExpr, WindowFunction, WindowRankType}, + window_exec::WindowExec, +}; +use 
object_store::{path::Path, ObjectMeta}; + +use crate::{ + convert_box_required, convert_required, + error::PlanSerDeError, + from_proto_binary_op, into_required, proto_error, protobuf, + protobuf::{ + physical_expr_node::ExprType, physical_plan_node::PhysicalPlanType, GenerateFunction, + }, + Schema, }; -use datafusion_ext_plans::agg_exec::AggExec; -use datafusion_ext_plans::broadcast_join_exec::BroadcastJoinExec; -use datafusion_ext_plans::debug_exec::DebugExec; -use datafusion_ext_plans::empty_partitions_exec::EmptyPartitionsExec; -use datafusion_ext_plans::expand_exec::ExpandExec; -use datafusion_ext_plans::ffi_reader_exec::FFIReaderExec; -use datafusion_ext_plans::filter_exec::FilterExec; -use datafusion_ext_plans::ipc_reader_exec::IpcReaderExec; -use datafusion_ext_plans::ipc_writer_exec::IpcWriterExec; -use datafusion_ext_plans::limit_exec::LimitExec; -use datafusion_ext_plans::parquet_exec::ParquetExec; -use datafusion_ext_plans::project_exec::ProjectExec; -use datafusion_ext_plans::rename_columns_exec::RenameColumnsExec; -use datafusion_ext_plans::rss_shuffle_writer_exec::RssShuffleWriterExec; -use datafusion_ext_plans::shuffle_writer_exec::ShuffleWriterExec; -use datafusion_ext_plans::sort_exec::SortExec; -use datafusion_ext_plans::sort_merge_join_exec::SortMergeJoinExec; -use object_store::path::Path; -use object_store::ObjectMeta; - -use crate::error::PlanSerDeError; -use crate::protobuf::physical_expr_node::ExprType; -use crate::protobuf::physical_plan_node::PhysicalPlanType; -use crate::protobuf::GenerateFunction; -use crate::{convert_box_required, convert_required, into_required, protobuf, Schema}; -use crate::{from_proto_binary_op, proto_error}; -use datafusion_ext_exprs::cast::TryCastExpr; -use datafusion_ext_exprs::get_indexed_field::GetIndexedFieldExpr; -use datafusion_ext_exprs::get_map_value::GetMapValueExpr; -use datafusion_ext_exprs::named_struct::NamedStructExpr; -use datafusion_ext_exprs::spark_scalar_subquery_wrapper::SparkScalarSubqueryWrapperExpr; -use datafusion_ext_exprs::spark_udf_wrapper::SparkUDFWrapperExpr; -use datafusion_ext_exprs::string_contains::StringContainsExpr; -use datafusion_ext_exprs::string_ends_with::StringEndsWithExpr; -use datafusion_ext_exprs::string_starts_with::StringStartsWithExpr; -use datafusion_ext_plans::broadcast_nested_loop_join_exec::BroadcastNestedLoopJoinExec; -use datafusion_ext_plans::generate::create_generator; -use datafusion_ext_plans::generate_exec::GenerateExec; -use datafusion_ext_plans::parquet_sink_exec::ParquetSinkExec; -use datafusion_ext_plans::window::{WindowExpr, WindowFunction, WindowRankType}; -use datafusion_ext_plans::window_exec::WindowExec; fn bind( expr_in: Arc, @@ -563,6 +567,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { physical_groupings, physical_aggs, agg.initial_input_buffer_offset as usize, + agg.supports_partial_skipping, input, )?)) } diff --git a/native-engine/blaze-serde/src/lib.rs b/native-engine/blaze-serde/src/lib.rs index cc6f3339..dec025da 100644 --- a/native-engine/blaze-serde/src/lib.rs +++ b/native-engine/blaze-serde/src/lib.rs @@ -12,14 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. 
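The supports_partial_skipping flag added to AggExecNode and passed into AggExec::try_new above, together with the PARTIAL_AGG_SKIPPING_ENABLE / _RATIO / _MIN_ROWS conf keys, is what lets a partial aggregate give up and pass rows through when grouping is not reducing the data. The actual decision logic lives in agg_exec.rs / agg_tables.rs, which are not shown in this excerpt; the following is only a hedged sketch of the kind of heuristic those conf keys describe, with illustrative names:

    // Illustrative only: fall back to pass-through once enough rows have been seen
    // and the number of distinct groups stays close to the number of input rows.
    // The two thresholds correspond to PARTIAL_AGG_SKIPPING_MIN_ROWS and
    // PARTIAL_AGG_SKIPPING_RATIO; the real operator may use different conditions.
    fn should_skip_partial_agg(
        num_input_rows: usize,
        num_groups: usize,
        min_rows: usize,
        min_cardinality_ratio: f64,
    ) -> bool {
        num_input_rows >= min_rows
            && num_groups as f64 / num_input_rows as f64 >= min_cardinality_ratio
    }
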
-use crate::error::PlanSerDeError; +use std::sync::Arc; + use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, Schema, TimeUnit}; -use datafusion::logical_expr::Operator; -use datafusion::physical_plan::joins::utils::JoinSide; -use datafusion::prelude::JoinType; -use datafusion::scalar::ScalarValue; +use datafusion::{ + logical_expr::Operator, physical_plan::joins::utils::JoinSide, prelude::JoinType, + scalar::ScalarValue, +}; use datafusion_ext_plans::agg::AggFunction; -use std::sync::Arc; + +use crate::error::PlanSerDeError; // include the generated protobuf source as a submodule #[allow(clippy::all)] @@ -306,8 +308,10 @@ impl TryInto for &protobuf::arrow_type::ArrowTypeEnu .ok_or_else(|| proto_error("Protobuf deserialization error: Map message missing required field 'value_type'"))? .as_ref(); - let vec_field = - vec![Arc::new(key_type.try_into()?), Arc::new(value_type.try_into()?)]; + let vec_field = vec![ + Arc::new(key_type.try_into()?), + Arc::new(value_type.try_into()?), + ]; let fields = Arc::new(Field::new( "entries", DataType::Struct(vec_field.into()), diff --git a/native-engine/blaze/src/exec.rs b/native-engine/blaze/src/exec.rs index fda8ef63..d05be263 100644 --- a/native-engine/blaze/src/exec.rs +++ b/native-engine/blaze/src/exec.rs @@ -12,26 +12,35 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::rt::NativeExecutionRuntime; -use crate::{handle_unwinded_scope, SESSION}; -use blaze_jni_bridge::jni_bridge::JavaClasses; -use blaze_jni_bridge::*; +use std::sync::Arc; + +use blaze_jni_bridge::{ + conf::{DoubleConf, IntConf}, + jni_bridge::JavaClasses, + *, +}; use blaze_serde::protobuf::TaskDefinition; -use datafusion::common::Result; -use datafusion::error::DataFusionError; -use datafusion::execution::disk_manager::DiskManagerConfig; -use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; -use datafusion::physical_plan::{displayable, ExecutionPlan}; -use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_ext_plans::common::memory_manager::MemManager; -use jni::objects::JClass; -use jni::objects::JObject; -use jni::JNIEnv; +use datafusion::{ + common::Result, + error::DataFusionError, + execution::{ + disk_manager::DiskManagerConfig, + runtime_env::{RuntimeConfig, RuntimeEnv}, + }, + physical_plan::{displayable, ExecutionPlan}, + prelude::{SessionConfig, SessionContext}, +}; +use datafusion_ext_plans::memmgr::MemManager; +use jni::{ + objects::{JClass, JObject}, + JNIEnv, +}; use log::LevelFilter; use once_cell::sync::OnceCell; use prost::Message; use simplelog::{ColorChoice, ConfigBuilder, TermLogger, TerminalMode, ThreadLogMode}; -use std::sync::Arc; + +use crate::{handle_unwinded_scope, rt::NativeExecutionRuntime, SESSION}; fn init_logging() { static LOGGING_INIT: OnceCell<()> = OnceCell::new(); @@ -66,8 +75,8 @@ pub extern "system" fn Java_org_apache_spark_sql_blaze_JniBridge_initNative( // init datafusion session context SESSION.get_or_try_init(|| { let max_memory = executor_memory_overhead as usize; - let memory_fraction = jni_call_static!(BlazeConf.memoryFraction() -> f64)?; - let batch_size = jni_call_static!(BlazeConf.batchSize() -> i32)? as usize; + let memory_fraction = conf::MEMORY_FRACTION.value()?; + let batch_size = conf::BATCH_SIZE.value()? 
as usize; MemManager::init((max_memory as f64 * memory_fraction) as usize); let session_config = SessionConfig::new().with_batch_size(batch_size); diff --git a/native-engine/blaze/src/lib.rs b/native-engine/blaze/src/lib.rs index 490dfcb5..2c7c0510 100644 --- a/native-engine/blaze/src/lib.rs +++ b/native-engine/blaze/src/lib.rs @@ -12,14 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::{any::Any, error::Error, fmt::Debug, panic::AssertUnwindSafe}; + use blaze_jni_bridge::*; use datafusion::prelude::SessionContext; use jni::objects::{JObject, JThrowable}; use once_cell::sync::OnceCell; -use std::any::Any; -use std::error::Error; -use std::fmt::Debug; -use std::panic::AssertUnwindSafe; mod exec; mod metrics; diff --git a/native-engine/blaze/src/metrics.rs b/native-engine/blaze/src/metrics.rs index a685fd51..ecb47fb2 100644 --- a/native-engine/blaze/src/metrics.rs +++ b/native-engine/blaze/src/metrics.rs @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::sync::Arc; + use blaze_jni_bridge::{jni_call, jni_new_string}; -use datafusion::common::Result; -use datafusion::physical_plan::ExecutionPlan; +use datafusion::{common::Result, physical_plan::ExecutionPlan}; use jni::objects::JObject; -use std::sync::Arc; pub fn update_spark_metric_node( metric_node: JObject, diff --git a/native-engine/blaze/src/rt.rs b/native-engine/blaze/src/rt.rs index 14928134..0a529d75 100644 --- a/native-engine/blaze/src/rt.rs +++ b/native-engine/blaze/src/rt.rs @@ -12,36 +12,40 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::handle_unwinded_scope; -use crate::metrics::update_spark_metric_node; -use arrow::ffi_stream::FFI_ArrowArrayStream; -use blaze_jni_bridge::is_task_running; -use blaze_jni_bridge::jni_bridge::JavaClasses; +use std::{panic::AssertUnwindSafe, sync::Arc}; + +use arrow::{ + array::{Array, StructArray}, + ffi::{FFI_ArrowArray, FFI_ArrowSchema}, +}; use blaze_jni_bridge::{ - jni_call, jni_call_static, jni_exception_check, jni_exception_occurred, jni_new_global_ref, - jni_new_object, jni_new_string, + is_task_running, jni_bridge::JavaClasses, jni_call, jni_call_static, jni_exception_check, + jni_exception_occurred, jni_new_global_ref, jni_new_object, jni_new_string, }; -use datafusion::common::Result; -use datafusion::error::DataFusionError; -use datafusion::execution::context::TaskContext; -use datafusion::physical_plan::metrics::Time; -use datafusion::physical_plan::{ExecutionPlan, RecordBatchStream}; -use datafusion_ext_commons::ffi::MpscBatchReader; -use datafusion_ext_commons::streams::coalesce_stream::CoalesceStream; +use datafusion::{ + common::Result, + error::DataFusionError, + execution::context::TaskContext, + physical_plan::{ + metrics::{BaselineMetrics, ExecutionPlanMetricsSet}, + ExecutionPlan, + }, +}; +use datafusion_ext_commons::streams::coalesce_stream::CoalesceInput; use datafusion_ext_plans::common::output::WrappedRecordBatchSender; use futures::{FutureExt, StreamExt}; use jni::objects::{GlobalRef, JObject}; -use std::panic::AssertUnwindSafe; -use std::sync::Arc; use tokio::runtime::Runtime; +use crate::{handle_unwinded_scope, metrics::update_spark_metric_node}; + pub struct NativeExecutionRuntime { native_wrapper: GlobalRef, plan: Arc, task_context: Arc, partition: usize, + ffi_schema: Arc, rt: Runtime, - ffi_stream: Box, } impl NativeExecutionRuntime { @@ 
-51,33 +55,21 @@ impl NativeExecutionRuntime { partition: usize, context: Arc, ) -> Result { - let batch_size = context.session_config().batch_size(); - // execute plan to output stream let stream = plan.execute(partition, context.clone())?; + let schema = stream.schema(); // coalesce - let coalesce_compute_time = Time::new(); - let mut stream = Box::pin(CoalesceStream::new( + let mut stream = context.coalesce_with_default_batch_size( stream, - batch_size, - coalesce_compute_time, - )); - - // create mpsc channel for collecting batches - let (sender, receiver) = std::sync::mpsc::sync_channel(1); - - // create RecordBatchReader - let batch_reader = Box::new(MpscBatchReader { - schema: stream.schema(), - receiver, - }); + &BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), partition), + )?; - // create and export FFI_ArrowArrayStream - let ffi_stream = Box::new(FFI_ArrowArrayStream::new(batch_reader)); - let ffi_stream_ptr = &*ffi_stream as *const FFI_ArrowArrayStream; + // init ffi schema + let ffi_schema = Arc::new(FFI_ArrowSchema::try_from(schema.as_ref())?); jni_call!(BlazeCallNativeWrapper(native_wrapper.as_obj()) - .setArrowFFIStreamPtr(ffi_stream_ptr as i64) -> ())?; + .importSchema(ffi_schema.as_ref() as *const FFI_ArrowSchema as i64) -> () + )?; // create tokio runtime // propagate classloader and task context to spawned children threads @@ -100,13 +92,14 @@ impl NativeExecutionRuntime { plan, partition, rt, - ffi_stream, + ffi_schema, task_context: context, }; // spawn batch producer - let sender_cloned = sender.clone(); + let native_wrapper_cloned = native_wrapper.clone(); let consume_stream = move || async move { + let native_wrapper = native_wrapper_cloned; while let Some(batch) = AssertUnwindSafe(stream.next()) .catch_unwind() .await @@ -118,73 +111,42 @@ impl NativeExecutionRuntime { .transpose() .map_err(|err| DataFusionError::Execution(format!("{}", err)))? { - sender.send(Some(Ok(batch))).map_err(|err| { - DataFusionError::Execution(format!("sending batch error: {}", err)) - })?; + let ffi_array = FFI_ArrowArray::new(&StructArray::from(batch).into_data()); + jni_call!(BlazeCallNativeWrapper(native_wrapper.as_obj()) + .importBatch(&ffi_array as *const FFI_ArrowArray as i64, ) -> () + )?; } + jni_call!(BlazeCallNativeWrapper(native_wrapper.as_obj()).importBatch(0, 0) -> ())?; - sender.send(None).unwrap_or_else(|err| { - log::warn!( - "native execution [partition={}] completing channel error: {}", - partition, - err, - ); - }); - log::info!("native execution [partition={}] finished", partition); + log::info!("[partition={partition}] finished"); Ok::<_, DataFusionError>(()) }; nrt.rt.spawn(async move { let result = consume_stream().await; result.unwrap_or_else(|err| handle_unwinded_scope(|| -> Result<()> { let task_running = is_task_running(); - log::warn!( - "native execution [partition={}] broken (task_running: {}): {}", - partition, - task_running, - err, - ); if !task_running { log::warn!( - "native execution [partition={}] task completed/interrupted before native execution done", - partition, + "[partition={partition}] task completed/interrupted before native execution done", ); return Ok(()); } let cause = if jni_exception_check!()? { - log::error!( - "native execution [partition={}] panics with an java exception: {}", - partition, - err, - ); + log::error!("[partition={partition}] panics with an java exception: {err}"); Some(jni_exception_occurred!()?) 
} else { - log::error!( - "native execution [partition={}] panics: {}", - partition, - err, - ); + log::error!("[partition={partition}] panics: {err}"); None }; set_error( &native_wrapper, - &format!( - "native executing [partition={}] panics: {}", - partition, - err, - ), + &format!("[partition={partition}] panics: {err}"), cause.map(|e| e.as_obj()), )?; - - // terminate the MpscBatchReader after error is set - let _ = sender_cloned.send(None); - - log::info!( - "native execution [partition={}] exited abnormally.", - partition, - ); + log::info!("[partition={partition}] exited abnormally."); Ok::<_, DataFusionError>(()) })); }); @@ -194,7 +156,7 @@ impl NativeExecutionRuntime { pub fn finalize(self) { log::info!("native execution [partition={}] finalizing", self.partition); let _ = self.update_metrics(); - drop(self.ffi_stream); + drop(self.ffi_schema); drop(self.plan); WrappedRecordBatchSender::cancel_task(&self.task_context); // cancel all pending streams self.rt.shutdown_background(); diff --git a/native-engine/datafusion-ext-commons/Cargo.toml b/native-engine/datafusion-ext-commons/Cargo.toml index 9beefd9a..4b3e4342 100644 --- a/native-engine/datafusion-ext-commons/Cargo.toml +++ b/native-engine/datafusion-ext-commons/Cargo.toml @@ -24,7 +24,11 @@ num = "0.4.0" once_cell = "1.11.0" paste = "1.0.7" postcard = { version = "1.0.8", features = ["alloc"]} +slimmer_box = "0.6.5" tempfile = "3" thrift = "0.17.0" tokio = "1.34" zstd = "0.12.3" + +[dev-dependencies] +rand = "0.8.5" diff --git a/native-engine/datafusion-ext-commons/src/array_builder.rs b/native-engine/datafusion-ext-commons/src/array_builder.rs index a1d8dd0c..cc01b751 100644 --- a/native-engine/datafusion-ext-commons/src/array_builder.rs +++ b/native-engine/datafusion-ext-commons/src/array_builder.rs @@ -12,10 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use arrow::array::*; -use arrow::datatypes::*; -use arrow::error::Result as ArrowResult; -use arrow::record_batch::RecordBatch; +use arrow::{array::*, datatypes::*, error::Result as ArrowResult, record_batch::RecordBatch}; // NOTE: // we suggest not using this mod because array_builders do not support @@ -122,7 +119,7 @@ pub fn builder_extend( DataType::LargeBinary => append!(LargeBinary), DataType::Utf8 => append!(String), DataType::LargeUtf8 => append!(LargeString), - DataType::Decimal128(_, _) => append!(Decimal128), + DataType::Decimal128(..) 
=> append!(Decimal128), dt => unimplemented!("data type not supported in builder_extend: {:?}", dt), } } diff --git a/native-engine/datafusion-ext-plans/src/common/bytes_arena.rs b/native-engine/datafusion-ext-commons/src/bytes_arena.rs similarity index 96% rename from native-engine/datafusion-ext-plans/src/common/bytes_arena.rs rename to native-engine/datafusion-ext-commons/src/bytes_arena.rs index 09ecfbb2..a31d004a 100644 --- a/native-engine/datafusion-ext-plans/src/common/bytes_arena.rs +++ b/native-engine/datafusion-ext-commons/src/bytes_arena.rs @@ -36,7 +36,8 @@ impl BytesArena { let cur_buf_len = self.cur_buf().len(); let len = bytes.len(); - // freeze current buf if it's almost full and has no enough space for the given bytes + // freeze current buf if it's almost full and has no enough space for the given + // bytes if cur_buf_len > BUF_CAPACITY_ALMOST_FULL && cur_buf_len + len > BUF_CAPACITY_TARGET { self.freeze_cur_buf(); } @@ -59,7 +60,7 @@ impl BytesArena { /// specialized for merging two parts in sort-exec /// works like an IntoIterator, free memory of visited items - pub(crate) fn specialized_get_and_drop_last(&mut self, addr: u64) -> &[u8] { + pub fn specialized_get_and_drop_last(&mut self, addr: u64) -> &[u8] { let (id, offset, len) = unapply_arena_addr(addr); if id > 0 && !self.bufs[id - 1].is_empty() { self.bufs[id - 1].truncate(0); // drop last buf diff --git a/native-engine/datafusion-ext-commons/src/cast.rs b/native-engine/datafusion-ext-commons/src/cast.rs index f158e88b..ba022c27 100644 --- a/native-engine/datafusion-ext-commons/src/cast.rs +++ b/native-engine/datafusion-ext-commons/src/cast.rs @@ -12,15 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -use arrow::array::*; -use arrow::datatypes::*; +use std::{str::FromStr, sync::Arc}; + +use arrow::{array::*, datatypes::*}; use bigdecimal::{FromPrimitive, ToPrimitive}; -use datafusion::common::cast::{as_float32_array, as_float64_array}; -use datafusion::common::{DataFusionError, Result}; +use datafusion::common::{ + cast::{as_float32_array, as_float64_array}, + DataFusionError, Result, +}; use num::{cast::AsPrimitive, Bounded, Integer, Signed}; use paste::paste; -use std::str::FromStr; -use std::sync::Arc; pub fn cast(array: &dyn Array, cast_type: &DataType) -> Result { return cast_impl(array, cast_type, false); @@ -71,15 +72,15 @@ pub fn cast_impl( // spark compatible string to integer cast try_cast_string_array_to_integer(array, cast_type)? } - (&DataType::Utf8, &DataType::Decimal128(_, _)) => { + (&DataType::Utf8, &DataType::Decimal128(..)) => { // spark compatible string to decimal cast try_cast_string_array_to_decimal(array, cast_type)? } - (&DataType::Decimal128(_, _), DataType::Utf8) => { + (&DataType::Decimal128(..), DataType::Utf8) => { // spark compatible decimal to string cast try_cast_decimal_array_to_string(array, cast_type)? 
} - (&DataType::Timestamp(_, _), DataType::Float64) => { + (&DataType::Timestamp(..), DataType::Float64) => { // timestamp to f64 = timestamp to i64 to f64, only used in agg.sum() arrow::compute::cast( &arrow::compute::cast(array, &DataType::Int64)?, @@ -92,14 +93,14 @@ pub fn cast_impl( } (&DataType::List(_), DataType::List(to_field)) => { let list = as_list_array(array); - let casted_items = cast_impl(list.values(), to_field.data_type(), match_struct_fields)?; + let items = cast_impl(list.values(), to_field.data_type(), match_struct_fields)?; make_array(ArrayData::try_new( DataType::List(to_field.clone()), list.len(), list.nulls().map(|nb| nb.buffer().clone()), list.offset(), list.to_data().buffers().to_vec(), - vec![casted_items.into_data()], + vec![items.into_data()], )?) } (&DataType::Struct(_), DataType::Struct(to_fields)) => { @@ -136,7 +137,7 @@ pub fn cast_impl( let mut null_column_name = vec![]; let casted_array = to_fields .iter() - .map(|field: &FieldRef| { + .map(|field| { let col = struct_.column_by_name(field.name().as_str()); if col.is_some() { cast_impl(col.unwrap(), field.data_type(), match_struct_fields) @@ -171,21 +172,20 @@ pub fn cast_impl( )?) } } - (&DataType::Map(_, _), &DataType::Map(ref to_entries_field, to_sorted)) => { + (&DataType::Map(..), &DataType::Map(ref to_entries_field, to_sorted)) => { let map = as_map_array(array); - let casted_entries = cast_impl( + let entries = cast_impl( map.entries(), to_entries_field.data_type(), match_struct_fields, )?; - make_array(ArrayData::try_new( DataType::Map(to_entries_field.clone(), to_sorted), map.len(), map.nulls().map(|nb| nb.buffer().clone()), map.offset(), map.to_data().buffers().to_vec(), - vec![casted_entries.into_data()], + vec![entries.into_data()], )?) } _ => { @@ -322,16 +322,18 @@ fn to_integer(input: &str) return None; }; - // We are going to process the new digit and accumulate the result. However, before doing - // this, if the result is already smaller than the stopValue(Long.MIN_VALUE / radix), then - // result * 10 will definitely be smaller than minValue, and we can stop. + // We are going to process the new digit and accumulate the result. However, + // before doing this, if the result is already smaller than the + // stopValue(Long.MIN_VALUE / radix), then result * 10 will definitely + // be smaller than minValue, and we can stop. if result < stop_value { return None; } result = result * radix - T::from_u8(digit).unwrap(); - // Since the previous result is less than or equal to stopValue(Long.MIN_VALUE / radix), we - // can just use `result > 0` to check overflow. If result overflows, we should stop. + // Since the previous result is less than or equal to stopValue(Long.MIN_VALUE / + // radix), we can just use `result > 0` to check overflow. If result + // overflows, we should stop. 
if result > T::zero() { return None; } @@ -371,9 +373,10 @@ fn to_decimal(input: &str, precision: u8, scale: i8) -> Option { #[cfg(test)] mod test { - use crate::cast::*; use datafusion::common::cast::as_int32_array; + use crate::cast::*; + #[test] fn test_float_to_int() { let f64_array: ArrayRef = Arc::new(Float64Array::from_iter(vec![ diff --git a/native-engine/datafusion-ext-commons/src/ffi.rs b/native-engine/datafusion-ext-commons/src/ffi.rs deleted file mode 100644 index b5404f8b..00000000 --- a/native-engine/datafusion-ext-commons/src/ffi.rs +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2022 The Blaze Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use arrow::datatypes::SchemaRef; -use arrow::error::ArrowError; -use arrow::record_batch::{RecordBatch, RecordBatchReader}; -use blaze_jni_bridge::is_task_running; -use datafusion::common::Result; -use std::sync::mpsc::Receiver; - -/// RecordBatchReader for FFI_ArrowArrayStraem -pub struct MpscBatchReader { - pub schema: SchemaRef, - pub receiver: Receiver>>, -} - -impl RecordBatchReader for MpscBatchReader { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - -impl Iterator for MpscBatchReader { - type Item = Result; - - fn next(&mut self) -> Option { - self.receiver - .recv() - .unwrap_or_else(|err| { - // sender is unexpectedly died, terminate this stream - // errors should have been handled in sender side - let task_running = is_task_running(); - log::warn!( - "MpscBatchReader broken (task_running={}): {}", - task_running, - err, - ); - None - }) - .map(|result| result.map_err(|err| err.into())) - } -} diff --git a/native-engine/datafusion-ext-commons/src/hadoop_fs.rs b/native-engine/datafusion-ext-commons/src/hadoop_fs.rs index 7d58f0dd..fc619480 100644 --- a/native-engine/datafusion-ext-commons/src/hadoop_fs.rs +++ b/native-engine/datafusion-ext-commons/src/hadoop_fs.rs @@ -16,8 +16,7 @@ use blaze_jni_bridge::{ jni_call, jni_call_static, jni_new_direct_byte_buffer, jni_new_global_ref, jni_new_object, jni_new_string, }; -use datafusion::error::Result; -use datafusion::physical_plan::metrics::Time; +use datafusion::{error::Result, physical_plan::metrics::Time}; use jni::objects::{GlobalRef, JObject}; pub struct Fs { diff --git a/native-engine/datafusion-ext-commons/src/io/batch_serde.rs b/native-engine/datafusion-ext-commons/src/io/batch_serde.rs index 696b733f..aa85ccb2 100644 --- a/native-engine/datafusion-ext-commons/src/io/batch_serde.rs +++ b/native-engine/datafusion-ext-commons/src/io/batch_serde.rs @@ -12,17 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. 
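The comments in to_integer above describe Spark's overflow-aware parsing trick: digits are accumulated into a negative total so that the type's minimum value stays representable, each step is first bounded by stop_value = MIN / radix, and a wrap past MIN flips the sign positive and is rejected. A self-contained i32 sketch of that check, written with wrapping arithmetic to keep the overflow test explicit; this illustrates the technique only and is not the generic implementation in cast.rs:

    // Spark-style overflow-aware integer parse (sketch): accumulate negatively,
    // bound each step by stop_value = i32::MIN / radix, and detect wrap-around
    // past i32::MIN by the sign turning positive.
    fn parse_i32_spark_style(input: &str) -> Option<i32> {
        let radix = 10i32;
        let stop_value = i32::MIN / radix;
        let (negative, digits) = match input.strip_prefix('-') {
            Some(rest) => (true, rest),
            None => (false, input.strip_prefix('+').unwrap_or(input)),
        };
        if digits.is_empty() {
            return None;
        }
        let mut result = 0i32;
        for b in digits.bytes() {
            let digit = char::from(b).to_digit(10)? as i32;
            if result < stop_value {
                return None; // result * radix would already pass i32::MIN
            }
            result = result.wrapping_mul(radix).wrapping_sub(digit);
            if result > 0 {
                return None; // wrapped past i32::MIN
            }
        }
        if negative {
            Some(result)
        } else {
            result.checked_neg() // i32::MIN has no positive counterpart
        }
    }
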
-use crate::io::{read_bytes_slice, read_len, write_len}; -use arrow::array::*; -use arrow::buffer::{Buffer, MutableBuffer}; -use arrow::datatypes::*; -use arrow::record_batch::{RecordBatch, RecordBatchOptions}; +use std::{ + io::{BufReader, BufWriter, Read, Write}, + sync::{ + atomic::{AtomicUsize, Ordering::SeqCst}, + Arc, + }, +}; + +use arrow::{ + array::*, + buffer::{Buffer, MutableBuffer}, + datatypes::*, + record_batch::{RecordBatch, RecordBatchOptions}, +}; use bitvec::prelude::BitVec; use datafusion::common::{DataFusionError, Result}; -use std::io::{BufReader, BufWriter, Read, Write}; -use std::sync::atomic::AtomicUsize; -use std::sync::atomic::Ordering::SeqCst; -use std::sync::Arc; + +use crate::io::{read_bytes_slice, read_len, write_len}; pub fn write_batch( batch: &RecordBatch, @@ -182,7 +189,7 @@ pub fn write_array(array: &dyn Array, output: &mut W) -> Result<()> { DataType::UInt64 => write_primitive!(UInt64), DataType::Float32 => write_primitive!(Float32), DataType::Float64 => write_primitive!(Float64), - DataType::Decimal128(_, _) => write_primitive!(Decimal128), + DataType::Decimal128(..) => write_primitive!(Decimal128), DataType::Utf8 => write_bytes_array(as_string_array(array), output)?, DataType::Binary => write_bytes_array(as_generic_binary_array::(array), output)?, DataType::Date32 => write_primitive!(Date32), @@ -192,7 +199,7 @@ pub fn write_array(array: &dyn Array, output: &mut W) -> Result<()> { DataType::Timestamp(TimeUnit::Microsecond, _) => write_primitive!(TimestampMicrosecond), DataType::Timestamp(TimeUnit::Nanosecond, _) => write_primitive!(TimestampNanosecond), DataType::List(_field) => write_list_array(as_list_array(array), output)?, - DataType::Map(_, _) => write_map_array(as_map_array(array), output)?, + DataType::Map(..) => write_map_array(as_map_array(array), output)?, DataType::Struct(_) => write_struct_array(as_struct_array(array), output)?, other => { return Err(DataFusionError::NotImplemented(format!( @@ -668,14 +675,15 @@ fn read_bytes_array( #[cfg(test)] mod test { - use crate::io::batch_serde::{read_batch, write_batch}; - use crate::io::name_batch; - use arrow::array::*; - use arrow::datatypes::*; - use arrow::record_batch::RecordBatch; + use std::{io::Cursor, sync::Arc}; + + use arrow::{array::*, datatypes::*, record_batch::RecordBatch}; use datafusion::assert_batches_eq; - use std::io::Cursor; - use std::sync::Arc; + + use crate::io::{ + batch_serde::{read_batch, write_batch}, + name_batch, + }; #[test] fn test_write_and_read_batch() { diff --git a/native-engine/datafusion-ext-commons/src/io/mod.rs b/native-engine/datafusion-ext-commons/src/io/mod.rs index dec64bb2..fb528e7e 100644 --- a/native-engine/datafusion-ext-commons/src/io/mod.rs +++ b/native-engine/datafusion-ext-commons/src/io/mod.rs @@ -12,15 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use arrow::array::StructArray; - use std::io::{Read, Seek, SeekFrom, Write}; -use arrow::datatypes::{DataType, SchemaRef}; -use arrow::record_batch::RecordBatch; +use arrow::{ + array::StructArray, + datatypes::{DataType, SchemaRef}, + record_batch::RecordBatch, +}; pub use batch_serde::{read_array, read_data_type, write_array, write_data_type}; -use datafusion::common::cast::as_struct_array; -use datafusion::common::Result; +use datafusion::common::{cast::as_struct_array, Result}; mod batch_serde; diff --git a/native-engine/datafusion-ext-commons/src/lib.rs b/native-engine/datafusion-ext-commons/src/lib.rs index ba8a0d8f..14ac87f0 100644 --- a/native-engine/datafusion-ext-commons/src/lib.rs +++ b/native-engine/datafusion-ext-commons/src/lib.rs @@ -1,5 +1,3 @@ -#![feature(new_uninit)] -#![feature(io_error_other)] // Copyright 2022 The Blaze Authors // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,18 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -use arrow::compute::concat; -use arrow::datatypes::SchemaRef; -use arrow::error::Result as ArrowResult; -use arrow::record_batch::{RecordBatch, RecordBatchOptions}; +#![feature(new_uninit)] +#![feature(io_error_other)] +#![feature(slice_swap_unchecked)] + +use arrow::{ + compute::concat, + datatypes::SchemaRef, + error::Result as ArrowResult, + record_batch::{RecordBatch, RecordBatchOptions}, +}; use log::trace; pub mod array_builder; +pub mod bytes_arena; pub mod cast; -pub mod ffi; pub mod hadoop_fs; pub mod io; pub mod loser_tree; +pub mod rdxsort; +pub mod slim_bytes; pub mod spark_hash; pub mod streams; pub mod uda; diff --git a/native-engine/datafusion-ext-plans/src/common/rdxsort.rs b/native-engine/datafusion-ext-commons/src/rdxsort.rs similarity index 98% rename from native-engine/datafusion-ext-plans/src/common/rdxsort.rs rename to native-engine/datafusion-ext-commons/src/rdxsort.rs index 663e9888..9400a0a9 100644 --- a/native-engine/datafusion-ext-plans/src/common/rdxsort.rs +++ b/native-engine/datafusion-ext-commons/src/rdxsort.rs @@ -71,9 +71,10 @@ pub fn radix_sort_u16_by(array: &mut [T], key: impl Fn(&T) -> u16) -> Vec); diff --git a/native-engine/datafusion-ext-commons/src/spark_hash.rs b/native-engine/datafusion-ext-commons/src/spark_hash.rs index b6ce9f5d..014c3241 100644 --- a/native-engine/datafusion-ext-commons/src/spark_hash.rs +++ b/native-engine/datafusion-ext-commons/src/spark_hash.rs @@ -16,10 +16,12 @@ use std::sync::Arc; -use arrow::array::*; -use arrow::datatypes::{ - ArrowDictionaryKeyType, ArrowNativeType, DataType, Int16Type, Int32Type, Int64Type, Int8Type, - TimeUnit, +use arrow::{ + array::*, + datatypes::{ + ArrowDictionaryKeyType, ArrowNativeType, DataType, Int16Type, Int32Type, Int64Type, + Int8Type, TimeUnit, + }, }; use datafusion::error::{DataFusionError, Result}; @@ -92,12 +94,14 @@ fn test_murmur3() { .into_iter() .map(|s| spark_compatible_murmur3_hash(s.as_bytes(), 42) as i32) .collect::>(); - let _expected = vec![142593372, 1485273170, -97053317, 1322437556, -396302900, 814637928]; + let _expected = vec![ + 142593372, 1485273170, -97053317, 1322437556, -396302900, 814637928, + ]; assert_eq!(_hashes, _expected) } macro_rules! 
hash_array { - ($array_type:ident, $column: ident, $hashes: ident) => { + ($array_type:ident, $column:ident, $hashes:ident) => { let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); if array.null_count() == 0 { for (i, hash) in $hashes.iter_mut().enumerate() { @@ -113,25 +117,8 @@ macro_rules! hash_array { }; } -macro_rules! hash_list { - ($array_type:ident, $column: ident, $hash: ident) => { - let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); - if array.null_count() == 0 { - for i in 0..array.len() { - *$hash = spark_compatible_murmur3_hash(&array.value(i), *$hash); - } - } else { - for i in 0..array.len() { - if !array.is_null(i) { - *$hash = spark_compatible_murmur3_hash(&array.value(i), *$hash); - } - } - } - }; -} - macro_rules! hash_array_primitive { - ($array_type:ident, $column: ident, $ty: ident, $hashes: ident) => { + ($array_type:ident, $column:ident, $ty:ident, $hashes:ident) => { let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); let values = array.values(); @@ -149,26 +136,8 @@ macro_rules! hash_array_primitive { }; } -macro_rules! hash_list_primitive { - ($array_type:ident, $column: ident, $ty: ident, $hash: ident) => { - let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); - let values = array.values(); - if array.null_count() == 0 { - for value in values.iter() { - *$hash = spark_compatible_murmur3_hash((*value as $ty).to_le_bytes(), *$hash); - } - } else { - for (i, value) in values.iter().enumerate() { - if !array.is_null(i) { - *$hash = spark_compatible_murmur3_hash((*value as $ty).to_le_bytes(), *$hash); - } - } - } - }; -} - macro_rules! hash_array_decimal { - ($array_type:ident, $column: ident, $hashes: ident) => { + ($array_type:ident, $column:ident, $hashes:ident) => { let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); if array.null_count() == 0 { @@ -186,7 +155,7 @@ macro_rules! hash_array_decimal { } macro_rules! hash_list_decimal { - ($array_type:ident, $column: ident, $hash: ident) => { + ($array_type:ident, $column:ident, $hash:ident) => { let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); if array.null_count() == 0 { @@ -237,225 +206,118 @@ fn create_hashes_dictionary( /// /// The number of rows to hash is determined by `hashes_buffer.len()`. 
/// `hashes_buffer` should be pre-sized appropriately -pub fn create_hashes<'a>( - arrays: &[ArrayRef], - hashes_buffer: &'a mut Vec, -) -> Result<&'a mut Vec> { +pub fn create_hashes<'a>(arrays: &[ArrayRef], hashes_buffer: &mut [u32]) -> Result<()> { for col in arrays { - match col.data_type() { - DataType::Null => {} - DataType::Boolean => { - let array = col.as_any().downcast_ref::().unwrap(); - if array.null_count() == 0 { - for (i, hash) in hashes_buffer.iter_mut().enumerate() { + hash_array(col, hashes_buffer)?; + } + Ok(()) +} + +fn hash_array(array: &ArrayRef, hashes_buffer: &mut [u32]) -> Result<()> { + match array.data_type() { + DataType::Null => {} + DataType::Boolean => { + let array = array.as_any().downcast_ref::().unwrap(); + if array.null_count() == 0 { + for (i, hash) in hashes_buffer.iter_mut().enumerate() { + *hash = spark_compatible_murmur3_hash( + (if array.value(i) { 1u32 } else { 0u32 }).to_le_bytes(), + *hash, + ); + } + } else { + for (i, hash) in hashes_buffer.iter_mut().enumerate() { + if !array.is_null(i) { *hash = spark_compatible_murmur3_hash( (if array.value(i) { 1u32 } else { 0u32 }).to_le_bytes(), *hash, ); } - } else { - for (i, hash) in hashes_buffer.iter_mut().enumerate() { - if !array.is_null(i) { - *hash = spark_compatible_murmur3_hash( - (if array.value(i) { 1u32 } else { 0u32 }).to_le_bytes(), - *hash, - ); - } - } } } + } + DataType::Int8 => { + hash_array_primitive!(Int8Array, array, i32, hashes_buffer); + } + DataType::Int16 => { + hash_array_primitive!(Int16Array, array, i32, hashes_buffer); + } + DataType::Int32 => { + hash_array_primitive!(Int32Array, array, i32, hashes_buffer); + } + DataType::Int64 => { + hash_array_primitive!(Int64Array, array, i64, hashes_buffer); + } + DataType::Float32 => { + hash_array_primitive!(Float32Array, array, f32, hashes_buffer); + } + DataType::Float64 => { + hash_array_primitive!(Float64Array, array, f64, hashes_buffer); + } + DataType::Timestamp(TimeUnit::Second, _) => { + hash_array_primitive!(TimestampSecondArray, array, i64, hashes_buffer); + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + hash_array_primitive!(TimestampMillisecondArray, array, i64, hashes_buffer); + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + hash_array_primitive!(TimestampMicrosecondArray, array, i64, hashes_buffer); + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + hash_array_primitive!(TimestampNanosecondArray, array, i64, hashes_buffer); + } + DataType::Date32 => { + hash_array_primitive!(Date32Array, array, i32, hashes_buffer); + } + DataType::Date64 => { + hash_array_primitive!(Date64Array, array, i64, hashes_buffer); + } + DataType::Binary => { + hash_array!(BinaryArray, array, hashes_buffer); + } + DataType::LargeBinary => { + hash_array!(LargeBinaryArray, array, hashes_buffer); + } + DataType::Utf8 => { + hash_array!(StringArray, array, hashes_buffer); + } + DataType::LargeUtf8 => { + hash_array!(LargeStringArray, array, hashes_buffer); + } + DataType::Decimal128(..) 
=> { + hash_array_decimal!(Decimal128Array, array, hashes_buffer); + } + DataType::Dictionary(index_type, _) => match **index_type { DataType::Int8 => { - hash_array_primitive!(Int8Array, col, i32, hashes_buffer); + create_hashes_dictionary::(array, hashes_buffer)?; } DataType::Int16 => { - hash_array_primitive!(Int16Array, col, i32, hashes_buffer); + create_hashes_dictionary::(array, hashes_buffer)?; } DataType::Int32 => { - hash_array_primitive!(Int32Array, col, i32, hashes_buffer); + create_hashes_dictionary::(array, hashes_buffer)?; } DataType::Int64 => { - hash_array_primitive!(Int64Array, col, i64, hashes_buffer); - } - DataType::Float32 => { - hash_array_primitive!(Float32Array, col, f32, hashes_buffer); - } - DataType::Float64 => { - hash_array_primitive!(Float64Array, col, f64, hashes_buffer); - } - DataType::Timestamp(TimeUnit::Second, _) => { - hash_array_primitive!(TimestampSecondArray, col, i64, hashes_buffer); - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - hash_array_primitive!(TimestampMillisecondArray, col, i64, hashes_buffer); - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - hash_array_primitive!(TimestampMicrosecondArray, col, i64, hashes_buffer); - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - hash_array_primitive!(TimestampNanosecondArray, col, i64, hashes_buffer); - } - DataType::Date32 => { - hash_array_primitive!(Date32Array, col, i32, hashes_buffer); - } - DataType::Date64 => { - hash_array_primitive!(Date64Array, col, i64, hashes_buffer); - } - DataType::Binary => { - hash_array!(BinaryArray, col, hashes_buffer); - } - DataType::LargeBinary => { - hash_array!(LargeBinaryArray, col, hashes_buffer); - } - DataType::Utf8 => { - hash_array!(StringArray, col, hashes_buffer); - } - DataType::LargeUtf8 => { - hash_array!(LargeStringArray, col, hashes_buffer); - } - DataType::Decimal128(_, _) => { - hash_array_decimal!(Decimal128Array, col, hashes_buffer); - } - DataType::Dictionary(index_type, _) => match **index_type { - DataType::Int8 => { - create_hashes_dictionary::(col, hashes_buffer)?; - } - DataType::Int16 => { - create_hashes_dictionary::(col, hashes_buffer)?; - } - DataType::Int32 => { - create_hashes_dictionary::(col, hashes_buffer)?; - } - DataType::Int64 => { - create_hashes_dictionary::(col, hashes_buffer)?; - } - _ => { - return Err(DataFusionError::Internal(format!( - "Unsupported dictionary type in hasher hashing: {}", - col.data_type(), - ))) - } - }, - DataType::List(field) => { - let list_array = col.as_any().downcast_ref::().unwrap(); - for (i, hash) in hashes_buffer.iter_mut().enumerate() { - let sub_array = &list_array.value(i); - match field.data_type() { - DataType::Boolean => { - let array = sub_array.as_any().downcast_ref::().unwrap(); - if array.null_count() == 0 { - for index in 0..array.len() { - *hash = spark_compatible_murmur3_hash( - (if array.value(index) { 1u32 } else { 0u32 }) - .to_le_bytes(), - *hash, - ); - } - } else { - for index in 0..array.len() { - if !array.is_null(index) { - *hash = spark_compatible_murmur3_hash( - (if array.value(index) { 1u32 } else { 0u32 }) - .to_le_bytes(), - *hash, - ); - } - } - } - } - DataType::Int8 => { - hash_list_primitive!(Int8Array, sub_array, i32, hash); - } - DataType::Int16 => { - hash_list_primitive!(Int16Array, sub_array, i32, hash); - } - DataType::Int32 => { - hash_list_primitive!(Int32Array, sub_array, i32, hash); - } - DataType::Int64 => { - hash_list_primitive!(Int64Array, sub_array, i64, hash); - } - DataType::Float32 => { - 
hash_list_primitive!(Float32Array, sub_array, f32, hash); - } - DataType::Float64 => { - hash_list_primitive!(Float64Array, sub_array, f64, hash); - } - DataType::Timestamp(TimeUnit::Second, _) => { - hash_list_primitive!(TimestampSecondArray, sub_array, i64, hash); - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - hash_list_primitive!(TimestampMillisecondArray, sub_array, i64, hash); - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - hash_list_primitive!(TimestampMicrosecondArray, sub_array, i64, hash); - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - hash_list_primitive!(TimestampNanosecondArray, sub_array, i64, hash); - } - DataType::Date32 => { - hash_list_primitive!(Date32Array, sub_array, i32, hash); - } - DataType::Date64 => { - hash_list_primitive!(Date64Array, sub_array, i64, hash); - } - DataType::Binary => { - hash_list!(BinaryArray, sub_array, hash); - } - DataType::LargeBinary => { - hash_list!(LargeBinaryArray, sub_array, hash); - } - DataType::Utf8 => { - hash_list!(StringArray, sub_array, hash); - } - DataType::LargeUtf8 => { - hash_list!(LargeStringArray, sub_array, hash); - } - DataType::Decimal128(_, _) => { - hash_list_decimal!(Decimal128Array, sub_array, hash); - } - _ => { - return Err(DataFusionError::Internal(format!( - "Unsupported list data type in hasher: {}", - field.data_type() - ))); - } - } - } - } - DataType::Map(_, _) => { - let map_array = col.as_any().downcast_ref::().unwrap(); - let key_array = map_array.keys(); - let value_array = map_array.values(); - let offsets_buffer = map_array.value_offsets(); - let mut cur_offset = 0; - for (&next_offset, hash) in - offsets_buffer.iter().skip(1).zip(hashes_buffer.iter_mut()) - { - for idx in cur_offset..next_offset { - update_map_hashes(key_array, idx, hash)?; - update_map_hashes(value_array, idx, hash)?; - } - cur_offset = next_offset; - } - } - DataType::Struct(_) => { - let struct_array = col.as_any().downcast_ref::().unwrap(); - create_hashes(struct_array.columns(), hashes_buffer)?; + create_hashes_dictionary::(array, hashes_buffer)?; } _ => { - // This is internal because we should have caught this before. return Err(DataFusionError::Internal(format!( - "Unsupported data type in hasher: {}", - col.data_type() - ))); + "Unsupported dictionary type in hasher hashing: {}", + array.data_type(), + ))) + } + }, + _ => { + for idx in 0..array.len() { + hash_one(array, idx, &mut hashes_buffer[idx])?; } } } - Ok(hashes_buffer) + Ok(()) } -macro_rules! hash_map_primitive { - ($array_type:ident, $column: ident, $ty: ident, $hash: ident, $idx: ident) => { +macro_rules! hash_one_primitive { + ($array_type:ident, $column:ident, $ty:ident, $hash:ident, $idx:ident) => { let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); *$hash = spark_compatible_murmur3_hash( (array.value($idx as usize) as $ty).to_le_bytes(), @@ -464,90 +326,110 @@ macro_rules! hash_map_primitive { }; } -macro_rules! hash_map_binary { - ($array_type:ident, $column: ident, $hash: ident, $idx: ident) => { +macro_rules! hash_one_binary { + ($array_type:ident, $column:ident, $hash:ident, $idx:ident) => { let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); *$hash = spark_compatible_murmur3_hash(&array.value($idx as usize), *$hash); }; } -macro_rules! hash_map_decimal { - ($array_type:ident, $column: ident, $hash: ident, $idx: ident) => { +macro_rules! 
hash_one_decimal { + ($array_type:ident, $column:ident, $hash:ident, $idx:ident) => { let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); *$hash = spark_compatible_murmur3_hash(array.value($idx as usize).to_le_bytes(), *$hash); }; } -fn update_map_hashes(array: &ArrayRef, idx: i32, hash: &mut u32) -> Result<()> { - if array.is_valid(idx as usize) { - match array.data_type() { +fn hash_one(col: &ArrayRef, idx: usize, hash: &mut u32) -> Result<()> { + if col.is_valid(idx) { + match col.data_type() { + DataType::Null => {} DataType::Boolean => { - let array = array.as_any().downcast_ref::().unwrap(); + let array = col.as_any().downcast_ref::().unwrap(); *hash = spark_compatible_murmur3_hash( - (if array.value(idx as usize) { - 1u32 - } else { - 0u32 - }) - .to_le_bytes(), + (if array.value(idx) { 1u32 } else { 0u32 }).to_le_bytes(), *hash, ); } DataType::Int8 => { - hash_map_primitive!(Int8Array, array, i32, hash, idx); + hash_one_primitive!(Int8Array, col, i32, hash, idx); } DataType::Int16 => { - hash_map_primitive!(Int16Array, array, i32, hash, idx); + hash_one_primitive!(Int16Array, col, i32, hash, idx); } DataType::Int32 => { - hash_map_primitive!(Int32Array, array, i32, hash, idx); + hash_one_primitive!(Int32Array, col, i32, hash, idx); } DataType::Int64 => { - hash_map_primitive!(Int64Array, array, i64, hash, idx); + hash_one_primitive!(Int64Array, col, i64, hash, idx); } DataType::Float32 => { - hash_map_primitive!(Float32Array, array, f32, hash, idx); + hash_one_primitive!(Float32Array, col, f32, hash, idx); } DataType::Float64 => { - hash_map_primitive!(Float64Array, array, f64, hash, idx); + hash_one_primitive!(Float64Array, col, f64, hash, idx); } DataType::Timestamp(TimeUnit::Second, None) => { - hash_map_primitive!(TimestampSecondArray, array, i64, hash, idx); + hash_one_primitive!(TimestampSecondArray, col, i64, hash, idx); } DataType::Timestamp(TimeUnit::Millisecond, None) => { - hash_map_primitive!(TimestampMillisecondArray, array, i64, hash, idx); + hash_one_primitive!(TimestampMillisecondArray, col, i64, hash, idx); } DataType::Timestamp(TimeUnit::Microsecond, None) => { - hash_map_primitive!(TimestampMicrosecondArray, array, i64, hash, idx); + hash_one_primitive!(TimestampMicrosecondArray, col, i64, hash, idx); } DataType::Timestamp(TimeUnit::Nanosecond, _) => { - hash_map_primitive!(TimestampNanosecondArray, array, i64, hash, idx); + hash_one_primitive!(TimestampNanosecondArray, col, i64, hash, idx); } DataType::Date32 => { - hash_map_primitive!(Date32Array, array, i32, hash, idx); + hash_one_primitive!(Date32Array, col, i32, hash, idx); } DataType::Date64 => { - hash_map_primitive!(Date64Array, array, i64, hash, idx); + hash_one_primitive!(Date64Array, col, i64, hash, idx); } DataType::Binary => { - hash_map_binary!(BinaryArray, array, hash, idx); + hash_one_binary!(BinaryArray, col, hash, idx); } DataType::LargeBinary => { - hash_map_binary!(LargeBinaryArray, array, hash, idx); + hash_one_binary!(LargeBinaryArray, col, hash, idx); } DataType::Utf8 => { - hash_map_binary!(StringArray, array, hash, idx); + hash_one_binary!(StringArray, col, hash, idx); } DataType::LargeUtf8 => { - hash_map_binary!(LargeStringArray, array, hash, idx); + hash_one_binary!(LargeStringArray, col, hash, idx); } - DataType::Decimal128(_, _) => { - hash_map_decimal!(Decimal128Array, array, hash, idx); + DataType::Decimal128(..) => { + hash_one_decimal!(Decimal128Array, col, hash, idx); + } + DataType::List(..) 
=> { + let list_array = col.as_any().downcast_ref::().unwrap(); + let value_array = list_array.value(idx); + for i in 0..value_array.len() { + hash_one(&value_array, i, hash)?; + } + } + DataType::Map(..) => { + let map_array = col.as_any().downcast_ref::().unwrap(); + let kv_array = map_array.value(idx); + let key_array = kv_array.column(0); + let value_array = kv_array.column(1); + for i in 0..kv_array.len() { + hash_one(key_array, i, hash)?; + hash_one(value_array, i, hash)?; + } + } + DataType::Struct(_) => { + let struct_array = col.as_any().downcast_ref::().unwrap(); + for col in struct_array.columns() { + hash_one(col, idx, hash)?; + } } _ => { + // This is internal because we should have caught this before. return Err(DataFusionError::Internal(format!( - "Unsupported map key/value data type in hasher: {}", - array.data_type() + "Unsupported data type in hasher: {}", + col.data_type() ))); } } @@ -567,13 +449,16 @@ pub fn pmod(hash: u32, n: usize) -> usize { mod tests { use std::sync::Arc; - use crate::spark_hash::{create_hashes, pmod, spark_compatible_murmur3_hash}; - use arrow::array::{ - make_array, Array, ArrayData, ArrayRef, Int32Array, Int64Array, Int8Array, MapArray, - StringArray, StructArray, UInt32Array, + use arrow::{ + array::{ + make_array, Array, ArrayData, ArrayRef, Int32Array, Int64Array, Int8Array, MapArray, + StringArray, StructArray, UInt32Array, + }, + buffer::Buffer, + datatypes::{DataType, Field, ToByteSlice}, }; - use arrow::buffer::Buffer; - use arrow::datatypes::{DataType, Field, ToByteSlice}; + + use crate::spark_hash::{create_hashes, pmod, spark_compatible_murmur3_hash}; #[test] fn test_list() { @@ -639,7 +524,8 @@ mod tests { let mut hashes = vec![42; 5]; create_hashes(&[i], &mut hashes).unwrap(); - // generated with Murmur3Hash(Seq(Literal("")), 42).eval() since Spark is tested against this as well + // generated with Murmur3Hash(Seq(Literal("")), 42).eval() since Spark is tested + // against this as well let expected = vec![3286402344, 2486176763, 142593372, 885025535, 2395000894]; assert_eq!(hashes, expected); } diff --git a/native-engine/datafusion-ext-commons/src/streams/coalesce_stream.rs b/native-engine/datafusion-ext-commons/src/streams/coalesce_stream.rs index 92a53393..6359bd05 100644 --- a/native-engine/datafusion-ext-commons/src/streams/coalesce_stream.rs +++ b/native-engine/datafusion-ext-commons/src/streams/coalesce_stream.rs @@ -12,18 +12,65 @@ // See the License for the specific language governing permissions and // limitations under the License. 
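A usage sketch of the refactored hashing entry point above (spark_hash.rs): `create_hashes` now fills a caller-provided, pre-sized `&mut [u32]` and returns `Result<()>`, with the per-type list/map macros folded into the recursive `hash_one`, so nested List/Map/Struct columns hash through the same path. The column values below are made up for illustration; only the call shape comes from the patch.

use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array, StringArray};
use datafusion::error::Result;
use datafusion_ext_commons::spark_hash::create_hashes;

fn hash_two_columns() -> Result<Vec<u32>> {
    let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let b: ArrayRef = Arc::new(StringArray::from(vec!["x", "y", "z"]));

    // the buffer is pre-sized to the row count and seeded with 42,
    // the seed Spark uses for Murmur3Hash
    let mut hashes = vec![42u32; 3];
    create_hashes(&[a, b], &mut hashes)?;
    Ok(hashes)
}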
-use crate::concat_batches; -use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; -use datafusion::common::Result; -use datafusion::physical_plan::metrics::Time; -use datafusion::physical_plan::{RecordBatchStream, SendableRecordBatchStream}; +use std::{ + pin::Pin, + sync::Arc, + task::{ready, Context, Poll}, +}; + +use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use datafusion::{ + common::Result, + execution::TaskContext, + physical_plan::{ + metrics::{BaselineMetrics, Time}, + RecordBatchStream, SendableRecordBatchStream, + }, +}; use futures::{Stream, StreamExt}; -use std::pin::Pin; -use std::task::{ready, Context, Poll}; + +use crate::concat_batches; const STAGING_BATCHES_MEM_SIZE_LIMIT: usize = 1 << 26; // limit output batch size to 64MB +pub trait CoalesceInput { + fn coalesce_input( + &self, + input: SendableRecordBatchStream, + batch_size: usize, + metrics: &BaselineMetrics, + ) -> Result; + + fn coalesce_with_default_batch_size( + &self, + input: SendableRecordBatchStream, + metrics: &BaselineMetrics, + ) -> Result; +} + +impl CoalesceInput for Arc { + fn coalesce_input( + &self, + input: SendableRecordBatchStream, + batch_size: usize, + metrics: &BaselineMetrics, + ) -> Result { + Ok(Box::pin(CoalesceStream::new( + input, + batch_size, + metrics.elapsed_compute().clone(), + ))) + } + + fn coalesce_with_default_batch_size( + &self, + input: SendableRecordBatchStream, + metrics: &BaselineMetrics, + ) -> Result { + self.coalesce_input(input, self.session_config().batch_size(), metrics) + } +} + pub struct CoalesceStream { input: SendableRecordBatchStream, staging_batches: Vec, diff --git a/native-engine/datafusion-ext-commons/src/streams/ffi_stream.rs b/native-engine/datafusion-ext-commons/src/streams/ffi_stream.rs index d1687346..54233c88 100644 --- a/native-engine/datafusion-ext-commons/src/streams/ffi_stream.rs +++ b/native-engine/datafusion-ext-commons/src/streams/ffi_stream.rs @@ -12,19 +12,30 @@ // See the License for the specific language governing permissions and // limitations under the License. -use arrow::array::StructArray; -use arrow::datatypes::SchemaRef; -use arrow::ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}; -use arrow::record_batch::RecordBatch; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use arrow::{ + array::StructArray, + datatypes::SchemaRef, + ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, + record_batch::RecordBatch, +}; use blaze_jni_bridge::{jni_call, jni_new_object}; -use datafusion::error::Result; -use datafusion::physical_plan::metrics::{BaselineMetrics, Count}; -use datafusion::physical_plan::RecordBatchStream; +use datafusion::{ + error::Result, + physical_plan::{ + metrics::{BaselineMetrics, Count}, + RecordBatchStream, + }, +}; use futures::Stream; -use jni::objects::{GlobalRef, JObject}; -use jni::sys::{jboolean, JNI_TRUE}; -use std::pin::Pin; -use std::task::{Context, Poll}; +use jni::{ + objects::{GlobalRef, JObject}, + sys::{jboolean, JNI_TRUE}, +}; pub struct FFIReaderStream { schema: SchemaRef, diff --git a/native-engine/datafusion-ext-commons/src/streams/ipc_stream.rs b/native-engine/datafusion-ext-commons/src/streams/ipc_stream.rs index 4343e28b..a65c1d85 100644 --- a/native-engine/datafusion-ext-commons/src/streams/ipc_stream.rs +++ b/native-engine/datafusion-ext-commons/src/streams/ipc_stream.rs @@ -12,26 +12,32 @@ // See the License for the specific language governing permissions and // limitations under the License. 
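The `CoalesceInput` trait added to coalesce_stream.rs above hangs two helpers off `Arc<TaskContext>` so operators can coalesce small input batches without building a `CoalesceStream` by hand. A hypothetical call site, assuming the trait is reachable at the usual `streams::coalesce_stream` path:

use std::sync::Arc;

use datafusion::{
    common::Result,
    execution::TaskContext,
    physical_plan::{metrics::BaselineMetrics, SendableRecordBatchStream},
};
use datafusion_ext_commons::streams::coalesce_stream::CoalesceInput;

fn wrap_input(
    context: Arc<TaskContext>,
    input: SendableRecordBatchStream,
    metrics: BaselineMetrics,
) -> Result<SendableRecordBatchStream> {
    // picks up the session's configured batch size;
    // coalesce_input() takes an explicit size instead
    context.coalesce_with_default_batch_size(input, &metrics)
}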
-use std::fmt::Debug; +use std::{ + fmt::Debug, + fs::File, + io::{BufReader, Error as IoError, Read, Seek, SeekFrom}, + pin::Pin, + task::{Context, Poll}, +}; -use crate::io::read_one_batch; -use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; +use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; use blaze_jni_bridge::{ jni_call, jni_get_object_class, jni_get_string, jni_new_direct_byte_buffer, jni_new_global_ref, }; -use datafusion::error::Result; -use datafusion::physical_plan::metrics::{BaselineMetrics, Count}; -use datafusion::physical_plan::RecordBatchStream; +use datafusion::{ + error::Result, + physical_plan::{ + metrics::{BaselineMetrics, Count}, + RecordBatchStream, + }, +}; use futures::Stream; -use jni::objects::{GlobalRef, JObject}; -use jni::sys::{jboolean, jint, jlong, JNI_TRUE}; -use std::fs::File; -use std::io::{BufReader, Read, SeekFrom}; -use std::io::{Error as IoError, Seek}; -use std::pin::Pin; -use std::task::Context; -use std::task::Poll; +use jni::{ + objects::{GlobalRef, JObject}, + sys::{jboolean, jint, jlong, JNI_TRUE}, +}; + +use crate::io::read_one_batch; #[derive(Debug, Clone, Copy)] pub enum IpcReadMode { diff --git a/native-engine/datafusion-ext-commons/src/uda.rs b/native-engine/datafusion-ext-commons/src/uda.rs index 6d3f912b..d66e5d2b 100644 --- a/native-engine/datafusion-ext-commons/src/uda.rs +++ b/native-engine/datafusion-ext-commons/src/uda.rs @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -use arrow::array::{Array, ArrayData, ArrayRef, BooleanArray}; -use arrow::buffer::NullBuffer; -use arrow::datatypes::DataType; -use arrow::error::Result; -use std::any::Any; -use std::fmt::Debug; -use std::slice::Iter; -use std::sync::Arc; +use std::{any::Any, fmt::Debug, slice::Iter, sync::Arc}; + +use arrow::{ + array::{Array, ArrayData, ArrayRef, BooleanArray}, + buffer::NullBuffer, + datatypes::DataType, + error::Result, +}; #[derive(Debug, Clone)] pub struct UserDefinedArray { diff --git a/native-engine/datafusion-ext-exprs/src/cast.rs b/native-engine/datafusion-ext-exprs/src/cast.rs index 81775049..0e946c7c 100644 --- a/native-engine/datafusion-ext-exprs/src/cast.rs +++ b/native-engine/datafusion-ext-exprs/src/cast.rs @@ -12,17 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::{ + any::Any, + fmt::{Display, Formatter}, + hash::{Hash, Hasher}, + sync::Arc, +}; + +use arrow::{datatypes::*, record_batch::RecordBatch}; +use datafusion::{ + common::Result, logical_expr::ColumnarValue, physical_expr::PhysicalExpr, scalar::ScalarValue, +}; + use crate::down_cast_any_ref; -use arrow::datatypes::*; -use arrow::record_batch::RecordBatch; -use datafusion::common::Result; -use datafusion::logical_expr::ColumnarValue; -use datafusion::physical_expr::PhysicalExpr; -use datafusion::scalar::ScalarValue; -use std::any::Any; -use std::fmt::{Display, Formatter}; -use std::hash::{Hash, Hasher}; -use std::sync::Arc; /// cast expression compatible with spark #[derive(Debug, Hash)] @@ -101,18 +103,21 @@ impl PhysicalExpr for TryCastExpr { } #[cfg(test)] mod test { - use crate::cast::TryCastExpr; - use arrow::array::{ArrayRef, Float32Array, Int32Array, StringArray}; + use std::sync::Arc; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; + use arrow::{ + array::{ArrayRef, Float32Array, Int32Array, StringArray}, + datatypes::{DataType, Field, Schema}, + record_batch::RecordBatch, + }; use datafusion::physical_expr::{expressions as phys_expr, PhysicalExpr}; - use std::sync::Arc; + + use crate::cast::TryCastExpr; #[test] fn test_ok_1() { - //input: Array - //cast Float32 into Int32 + // input: Array + // cast Float32 into Int32 let float_arr: ArrayRef = Arc::new(Float32Array::from(vec![ Some(7.6), Some(9.0), @@ -156,8 +161,8 @@ mod test { #[test] fn test_ok_2() { - //input: Array - //cast Utf8 into Float32 + // input: Array + // cast Utf8 into Float32 let string_arr: ArrayRef = Arc::new(StringArray::from(vec![ Some("123"), Some("321.9"), @@ -195,8 +200,8 @@ mod test { #[test] fn test_ok_3() { - //input: Scalar - //cast Utf8 into Float32 + // input: Scalar + // cast Utf8 into Float32 let string_arr: ArrayRef = Arc::new(StringArray::from(vec![ Some("123"), Some("321.9"), diff --git a/native-engine/datafusion-ext-exprs/src/get_indexed_field.rs b/native-engine/datafusion-ext-exprs/src/get_indexed_field.rs index f77ac882..b859c29d 100644 --- a/native-engine/datafusion-ext-exprs/src/get_indexed_field.rs +++ b/native-engine/datafusion-ext-exprs/src/get_indexed_field.rs @@ -12,20 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use arrow::array::*; -use arrow::compute::*; -use arrow::datatypes::*; -use arrow::record_batch::RecordBatch; -use datafusion::common::cast::{as_list_array, as_struct_array}; -use datafusion::common::DataFusionError; -use datafusion::common::Result; -use datafusion::common::ScalarValue; -use datafusion::logical_expr::ColumnarValue; -use datafusion::physical_expr::PhysicalExpr; -use std::convert::TryInto; -use std::fmt::Debug; -use std::hash::{Hash, Hasher}; -use std::{any::Any, sync::Arc}; +use std::{ + any::Any, + convert::TryInto, + fmt::Debug, + hash::{Hash, Hasher}, + sync::Arc, +}; + +use arrow::{array::*, compute::*, datatypes::*, record_batch::RecordBatch}; +use datafusion::{ + common::{ + cast::{as_list_array, as_struct_array}, + DataFusionError, Result, ScalarValue, + }, + logical_expr::ColumnarValue, + physical_expr::PhysicalExpr, +}; use crate::down_cast_any_ref; @@ -187,16 +190,17 @@ fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result { #[cfg(test)] mod test { - use super::GetIndexedFieldExpr; - use arrow::array::*; - use arrow::datatypes::*; - use arrow::record_batch::RecordBatch; - use datafusion::assert_batches_eq; - use datafusion::physical_plan::expressions::Column; - use datafusion::physical_plan::PhysicalExpr; - use datafusion::scalar::ScalarValue; use std::sync::Arc; + use arrow::{array::*, datatypes::*, record_batch::RecordBatch}; + use datafusion::{ + assert_batches_eq, + physical_plan::{expressions::Column, PhysicalExpr}, + scalar::ScalarValue, + }; + + use super::GetIndexedFieldExpr; + #[test] fn test_list() -> Result<(), Box> { let array: ArrayRef = Arc::new(ListArray::from_iter_primitive::(vec![ diff --git a/native-engine/datafusion-ext-exprs/src/get_map_value.rs b/native-engine/datafusion-ext-exprs/src/get_map_value.rs index 44100109..0ac5e29d 100644 --- a/native-engine/datafusion-ext-exprs/src/get_map_value.rs +++ b/native-engine/datafusion-ext-exprs/src/get_map_value.rs @@ -12,22 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::down_cast_any_ref; -use arrow::array::*; -use arrow::compute::{eq_dyn_binary_scalar, eq_dyn_bool_scalar, eq_dyn_scalar, eq_dyn_utf8_scalar}; -use arrow::datatypes::Field; +use std::{ + any::Any, + fmt::Debug, + hash::{Hash, Hasher}, + sync::Arc, +}; + use arrow::{ - datatypes::{DataType, Schema}, + array::*, + compute::{eq_dyn_binary_scalar, eq_dyn_bool_scalar, eq_dyn_scalar, eq_dyn_utf8_scalar}, + datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; -use datafusion::common::DataFusionError; -use datafusion::common::Result; -use datafusion::common::ScalarValue; -use datafusion::logical_expr::ColumnarValue; -use datafusion::physical_expr::PhysicalExpr; -use std::fmt::Debug; -use std::hash::{Hash, Hasher}; -use std::{any::Any, sync::Arc}; +use datafusion::{ + common::{DataFusionError, Result, ScalarValue}, + logical_expr::ColumnarValue, + physical_expr::PhysicalExpr, +}; + +use crate::down_cast_any_ref; /// expression to get value of a key in map array. 
#[derive(Debug, Hash)] @@ -288,20 +292,25 @@ fn get_data_type_field(data_type: &DataType) -> Result { #[cfg(test)] mod test { - use super::GetMapValueExpr; - use arrow::array::*; - use arrow::buffer::Buffer; - use arrow::datatypes::{DataType, Field, ToByteSlice}; - use arrow::record_batch::RecordBatch; - use datafusion::assert_batches_eq; - use datafusion::common::ScalarValue; - use datafusion::physical_plan::expressions::Column; - use datafusion::physical_plan::PhysicalExpr; use std::sync::Arc; + use arrow::{ + array::*, + buffer::Buffer, + datatypes::{DataType, Field, ToByteSlice}, + record_batch::RecordBatch, + }; + use datafusion::{ + assert_batches_eq, + common::ScalarValue, + physical_plan::{expressions::Column, PhysicalExpr}, + }; + + use super::GetMapValueExpr; + #[test] fn test_map_1() -> Result<(), Box> { - //Construct key and values + // Construct key and values let key_data = ArrayData::builder(DataType::Int32) .len(8) .add_buffer(Buffer::from(&[0, 1, 2, 3, 4, 5, 6, 7].to_byte_slice())) @@ -342,17 +351,25 @@ mod test { .build() .unwrap(); let map_array: ArrayRef = Arc::new(MapArray::from(map_data)); - let input_batch = RecordBatch::try_from_iter_with_nullable(vec![("col", map_array, true)])?; + let input_batch = + RecordBatch::try_from_iter_with_nullable(vec![("test col", map_array, true)])?; let get_indexed = Arc::new(GetMapValueExpr::new( - Arc::new(Column::new("col", 0)), + Arc::new(Column::new("test col", 0)), ScalarValue::from(7_i32), )); let output_array = get_indexed.evaluate(&input_batch)?.into_array(0); let output_batch = - RecordBatch::try_from_iter_with_nullable(vec![("col", output_array, true)])?; + RecordBatch::try_from_iter_with_nullable(vec![("test col", output_array, true)])?; - let expected = - vec!["+-----+", "| col |", "+-----+", "| |", "| |", "| 70 |", "+-----+"]; + let expected = vec![ + "+----------+", + "| test col |", + "+----------+", + "| |", + "| |", + "| 70 |", + "+----------+", + ]; assert_batches_eq!(expected, &[output_batch]); Ok(()) } @@ -370,17 +387,25 @@ mod test { MapArray::new_from_strings(keys.clone().into_iter(), &values_data, &entry_offsets) .unwrap(), ); - let input_batch = RecordBatch::try_from_iter_with_nullable(vec![("col", map_array, true)])?; + let input_batch = + RecordBatch::try_from_iter_with_nullable(vec![("test col", map_array, true)])?; let get_indexed = Arc::new(GetMapValueExpr::new( - Arc::new(Column::new("col", 0)), + Arc::new(Column::new("test col", 0)), ScalarValue::from("e"), )); let output_array = get_indexed.evaluate(&input_batch)?.into_array(0); let output_batch = - RecordBatch::try_from_iter_with_nullable(vec![("col", output_array, true)])?; + RecordBatch::try_from_iter_with_nullable(vec![("test col", output_array, true)])?; - let expected = - vec!["+-----+", "| col |", "+-----+", "| |", "| 40 |", "| |", "+-----+"]; + let expected = vec![ + "+----------+", + "| test col |", + "+----------+", + "| |", + "| 40 |", + "| |", + "+----------+", + ]; assert_batches_eq!(expected, &[output_batch]); Ok(()) } diff --git a/native-engine/datafusion-ext-exprs/src/lib.rs b/native-engine/datafusion-ext-exprs/src/lib.rs index 1f04b1f7..8fff701d 100644 --- a/native-engine/datafusion-ext-exprs/src/lib.rs +++ b/native-engine/datafusion-ext-exprs/src/lib.rs @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::{any::Any, sync::Arc}; + use datafusion::physical_expr::PhysicalExpr; -use std::any::Any; -use std::sync::Arc; pub mod cast; pub mod get_indexed_field; diff --git a/native-engine/datafusion-ext-exprs/src/named_struct.rs b/native-engine/datafusion-ext-exprs/src/named_struct.rs index 0703342e..3c3e3ccd 100644 --- a/native-engine/datafusion-ext-exprs/src/named_struct.rs +++ b/native-engine/datafusion-ext-exprs/src/named_struct.rs @@ -12,25 +12,31 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::down_cast_any_ref; - -use datafusion::arrow::array::StructArray; +use std::{ + any::Any, + fmt::{Debug, Formatter}, + hash::{Hash, Hasher}, + sync::Arc, +}; -use arrow::array::Array; -use arrow::datatypes::{Field, Fields, SchemaRef}; -use arrow::record_batch::RecordBatchOptions; -use datafusion::arrow::{ - datatypes::{DataType, Schema}, - record_batch::RecordBatch, +use arrow::{ + array::Array, + datatypes::{Field, Fields, SchemaRef}, + record_batch::RecordBatchOptions, +}; +use datafusion::{ + arrow::{ + array::StructArray, + datatypes::{DataType, Schema}, + record_batch::RecordBatch, + }, + common::{DataFusionError, Result}, + logical_expr::ColumnarValue, + physical_expr::{expr_list_eq_any_order, PhysicalExpr}, }; -use datafusion::common::DataFusionError; -use datafusion::common::Result; -use datafusion::logical_expr::ColumnarValue; -use datafusion::physical_expr::{expr_list_eq_any_order, PhysicalExpr}; use datafusion_ext_commons::io::name_batch; -use std::fmt::{Debug, Formatter}; -use std::hash::{Hash, Hasher}; -use std::{any::Any, sync::Arc}; + +use crate::down_cast_any_ref; /// expression to get a field of from NameStruct. #[derive(Debug, Hash)] @@ -131,15 +137,16 @@ impl PhysicalExpr for NamedStructExpr { #[cfg(test)] mod test { - use crate::named_struct::NamedStructExpr; - use arrow::array::*; - use arrow::datatypes::*; - use arrow::record_batch::RecordBatch; - use datafusion::assert_batches_eq; - use datafusion::physical_plan::expressions::Column; - use datafusion::physical_plan::PhysicalExpr; use std::sync::Arc; + use arrow::{array::*, datatypes::*, record_batch::RecordBatch}; + use datafusion::{ + assert_batches_eq, + physical_plan::{expressions::Column, PhysicalExpr}, + }; + + use crate::named_struct::NamedStructExpr; + #[test] fn test_list() -> Result<(), Box> { let array: ArrayRef = Arc::new(ListArray::from_iter_primitive::(vec![ @@ -152,7 +159,10 @@ mod test { let input_batch = RecordBatch::try_from_iter_with_nullable(vec![("cccccc1", array, true)])?; let named_struct = Arc::new(NamedStructExpr::try_new( - vec![Arc::new(Column::new("cccccc1", 0)), Arc::new(Column::new("cccccc1", 0))], + vec![ + Arc::new(Column::new("cccccc1", 0)), + Arc::new(Column::new("cccccc1", 0)), + ], DataType::Struct(Fields::from(vec![ Field::new( "field1", diff --git a/native-engine/datafusion-ext-exprs/src/spark_scalar_subquery_wrapper.rs b/native-engine/datafusion-ext-exprs/src/spark_scalar_subquery_wrapper.rs index d6fa6baf..67d1f90b 100644 --- a/native-engine/datafusion-ext-exprs/src/spark_scalar_subquery_wrapper.rs +++ b/native-engine/datafusion-ext-exprs/src/spark_scalar_subquery_wrapper.rs @@ -12,17 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::spark_udf_wrapper::SparkUDFWrapperExpr; -use arrow::datatypes::{DataType, Schema}; -use arrow::record_batch::{RecordBatch, RecordBatchOptions}; -use datafusion::common::{Result, ScalarValue}; -use datafusion::logical_expr::ColumnarValue; -use datafusion::physical_plan::PhysicalExpr; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + hash::Hasher, + sync::Arc, +}; + +use arrow::{ + datatypes::{DataType, Schema}, + record_batch::{RecordBatch, RecordBatchOptions}, +}; +use datafusion::{ + common::{Result, ScalarValue}, + logical_expr::ColumnarValue, + physical_plan::PhysicalExpr, +}; use once_cell::sync::OnceCell; -use std::any::Any; -use std::fmt::{Debug, Display, Formatter}; -use std::hash::Hasher; -use std::sync::Arc; + +use crate::spark_udf_wrapper::SparkUDFWrapperExpr; pub struct SparkScalarSubqueryWrapperExpr { pub serialized: Vec, diff --git a/native-engine/datafusion-ext-exprs/src/spark_udf_wrapper.rs b/native-engine/datafusion-ext-exprs/src/spark_udf_wrapper.rs index 2a60273d..57e49c74 100644 --- a/native-engine/datafusion-ext-exprs/src/spark_udf_wrapper.rs +++ b/native-engine/datafusion-ext-exprs/src/spark_udf_wrapper.rs @@ -12,30 +12,31 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::down_cast_any_ref; -use arrow::array::{as_struct_array, make_array, Array, ArrayRef, StructArray}; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + hash::Hasher, + sync::Arc, +}; -use arrow::record_batch::{RecordBatch, RecordBatchOptions}; +use arrow::{ + array::{as_struct_array, make_array, Array, ArrayRef, StructArray}, + datatypes::{DataType, Field, Schema, SchemaRef}, + ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, + record_batch::{RecordBatch, RecordBatchOptions}, +}; use blaze_jni_bridge::{ - is_task_running, jni_call, jni_call_static, jni_new_direct_byte_buffer, jni_new_global_ref, + conf, conf::IntConf, is_task_running, jni_call, jni_new_direct_byte_buffer, jni_new_global_ref, jni_new_object, }; -use datafusion::common::DataFusionError; -use datafusion::error::Result; -use datafusion::logical_expr::ColumnarValue; -use datafusion::physical_expr::utils::expr_list_eq_any_order; -use datafusion::physical_plan::PhysicalExpr; - +use datafusion::{ + common::DataFusionError, error::Result, logical_expr::ColumnarValue, + physical_expr::utils::expr_list_eq_any_order, physical_plan::PhysicalExpr, +}; use jni::objects::GlobalRef; use once_cell::sync::OnceCell; -use std::any::Any; -use std::fmt::{Debug, Display, Formatter}; -use std::hash::Hasher; - -use arrow::ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}; -use std::sync::Arc; +use crate::down_cast_any_ref; pub struct SparkUDFWrapperExpr { pub serialized: Vec, @@ -69,7 +70,7 @@ impl SparkUDFWrapperExpr { return_nullable: bool, params: Vec>, ) -> Result { - let num_threads = jni_call_static!(BlazeConf.udfWrapperNumThreads() -> i32)? as usize; + let num_threads = conf::UDF_WRAPPER_NUM_THREADS.value()? as usize; Ok(Self { serialized, return_type: return_type.clone(), diff --git a/native-engine/datafusion-ext-exprs/src/string_contains.rs b/native-engine/datafusion-ext-exprs/src/string_contains.rs index 726aedc7..1ee33bb7 100644 --- a/native-engine/datafusion-ext-exprs/src/string_contains.rs +++ b/native-engine/datafusion-ext-exprs/src/string_contains.rs @@ -12,17 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. 
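In spark_udf_wrapper.rs above, the raw `jni_call_static!(BlazeConf.udfWrapperNumThreads() -> i32)` lookup is replaced by the new typed conf API from blaze-jni-bridge. A minimal sketch of the same read, assuming only what the hunk itself shows (`conf::UDF_WRAPPER_NUM_THREADS` and the `IntConf` trait); the helper name is made up:

use blaze_jni_bridge::conf::{self, IntConf};
use datafusion::error::Result;

fn udf_wrapper_num_threads() -> Result<usize> {
    // IntConf::value() presumably still reaches BlazeConf on the JVM side,
    // but call sites now only see a typed constant
    Ok(conf::UDF_WRAPPER_NUM_THREADS.value()? as usize)
}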
+use std::{ + any::Any, + fmt::{Display, Formatter}, + hash::{Hash, Hasher}, + sync::Arc, +}; + +use arrow::{ + array::{Array, BooleanArray, StringArray}, + datatypes::{DataType, Schema}, + record_batch::RecordBatch, +}; +use datafusion::{ + common::{DataFusionError, Result, ScalarValue}, + logical_expr::ColumnarValue, + physical_plan::PhysicalExpr, +}; + use crate::down_cast_any_ref; -use arrow::array::{Array, BooleanArray, StringArray}; -use arrow::datatypes::{DataType, Schema}; -use arrow::record_batch::RecordBatch; -use datafusion::common::{DataFusionError, Result, ScalarValue}; -use datafusion::logical_expr::ColumnarValue; -use datafusion::physical_plan::PhysicalExpr; -use std::any::Any; -use std::fmt::{Display, Formatter}; -use std::hash::{Hash, Hasher}; -use std::sync::Arc; #[derive(Debug, Hash)] pub struct StringContainsExpr { @@ -114,13 +122,17 @@ impl PhysicalExpr for StringContainsExpr { #[cfg(test)] mod test { - use crate::string_contains::StringContainsExpr; - use arrow::array::{ArrayRef, BooleanArray, StringArray}; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use datafusion::physical_expr::{expressions as phys_expr, PhysicalExpr}; use std::sync::Arc; + use arrow::{ + array::{ArrayRef, BooleanArray, StringArray}, + datatypes::{DataType, Field, Schema}, + record_batch::RecordBatch, + }; + use datafusion::physical_expr::{expressions as phys_expr, PhysicalExpr}; + + use crate::string_contains::StringContainsExpr; + #[test] fn test_ok() { // create a StringArray from the vector diff --git a/native-engine/datafusion-ext-exprs/src/string_ends_with.rs b/native-engine/datafusion-ext-exprs/src/string_ends_with.rs index 3e8a3fa2..0dc7bd5c 100644 --- a/native-engine/datafusion-ext-exprs/src/string_ends_with.rs +++ b/native-engine/datafusion-ext-exprs/src/string_ends_with.rs @@ -12,17 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::{ + any::Any, + fmt::{Display, Formatter}, + hash::{Hash, Hasher}, + sync::Arc, +}; + +use arrow::{ + array::{Array, BooleanArray, StringArray}, + datatypes::{DataType, Schema}, + record_batch::RecordBatch, +}; +use datafusion::{ + common::{DataFusionError, Result, ScalarValue}, + logical_expr::ColumnarValue, + physical_plan::PhysicalExpr, +}; + use crate::down_cast_any_ref; -use arrow::array::{Array, BooleanArray, StringArray}; -use arrow::datatypes::{DataType, Schema}; -use arrow::record_batch::RecordBatch; -use datafusion::common::{DataFusionError, Result, ScalarValue}; -use datafusion::logical_expr::ColumnarValue; -use datafusion::physical_plan::PhysicalExpr; -use std::any::Any; -use std::fmt::{Display, Formatter}; -use std::hash::{Hash, Hasher}; -use std::sync::Arc; #[derive(Debug, Hash)] pub struct StringEndsWithExpr { @@ -117,14 +125,16 @@ impl PhysicalExpr for StringEndsWithExpr { #[cfg(test)] mod test { - use crate::string_ends_with::StringEndsWithExpr; - use arrow::array::{ArrayRef, BooleanArray, StringArray}; + use std::sync::Arc; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; + use arrow::{ + array::{ArrayRef, BooleanArray, StringArray}, + datatypes::{DataType, Field, Schema}, + record_batch::RecordBatch, + }; use datafusion::physical_expr::{expressions as phys_expr, PhysicalExpr}; - use std::sync::Arc; + use crate::string_ends_with::StringEndsWithExpr; #[test] fn test_array() { @@ -135,14 +145,14 @@ mod test { Some("rr".to_string()), Some("roser r".to_string()), ])); - //create a shema with the field + // create a shema with the field let schema = Arc::new(Schema::new(vec![Field::new("col2", DataType::Utf8, true)])); - //create a RecordBatch with the shema and StringArray + // create a RecordBatch with the shema and StringArray let batch = RecordBatch::try_new(schema, vec![string_array]).expect("Error creating RecordBatch"); - //test: col2 like '%rr' + // test: col2 like '%rr' let pattern = "rr".to_string(); let expr = Arc::new(StringEndsWithExpr::new( phys_expr::col("col2", &batch.schema()).unwrap(), @@ -166,7 +176,7 @@ mod test { #[test] fn test_scalar_string() { - //create a StringArray from the vector + // create a StringArray from the vector let string_array: ArrayRef = Arc::new(StringArray::from(vec![ Some("Hello, Rust".to_string()), Some("Hello, He".to_string()), @@ -174,14 +184,14 @@ mod test { Some("RustHe".to_string()), Some("HellHe".to_string()), ])); - //create a schema with the field + // create a schema with the field let schema = Arc::new(Schema::new(vec![Field::new("col3", DataType::Utf8, true)])); - //create a RecordBatch with the schema and StringArray + // create a RecordBatch with the schema and StringArray let batch = RecordBatch::try_new(schema, vec![string_array]).expect("Error creating RecordBatch"); - //test: col3 like "%He" + // test: col3 like "%He" let pattern = "He".to_string(); // select "Hello, Rust" like "%He" from batch let expr = Arc::new(StringEndsWithExpr::new( @@ -193,7 +203,7 @@ mod test { .expect("Error evaluating expr") .into_array(batch.num_rows()); - //verify result + // verify result let expected: ArrayRef = Arc::new(BooleanArray::from(vec![ Some(false), Some(false), diff --git a/native-engine/datafusion-ext-exprs/src/string_starts_with.rs b/native-engine/datafusion-ext-exprs/src/string_starts_with.rs index 7e65eb26..d46b37a7 100644 --- a/native-engine/datafusion-ext-exprs/src/string_starts_with.rs +++ b/native-engine/datafusion-ext-exprs/src/string_starts_with.rs @@ -12,17 +12,25 
@@ // See the License for the specific language governing permissions and // limitations under the License. +use std::{ + any::Any, + fmt::{Display, Formatter}, + hash::{Hash, Hasher}, + sync::Arc, +}; + +use arrow::{ + array::{Array, BooleanArray, StringArray}, + datatypes::{DataType, Schema}, + record_batch::RecordBatch, +}; +use datafusion::{ + common::{DataFusionError, Result, ScalarValue}, + logical_expr::ColumnarValue, + physical_plan::PhysicalExpr, +}; + use crate::down_cast_any_ref; -use arrow::array::{Array, BooleanArray, StringArray}; -use arrow::datatypes::{DataType, Schema}; -use arrow::record_batch::RecordBatch; -use datafusion::common::{DataFusionError, Result, ScalarValue}; -use datafusion::logical_expr::ColumnarValue; -use datafusion::physical_plan::PhysicalExpr; -use std::any::Any; -use std::fmt::{Display, Formatter}; -use std::hash::{Hash, Hasher}; -use std::sync::Arc; #[derive(Debug, Hash)] pub struct StringStartsWithExpr { @@ -116,12 +124,15 @@ impl PhysicalExpr for StringStartsWithExpr { #[cfg(test)] mod test { - use arrow::array::{ArrayRef, BooleanArray, StringArray}; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use datafusion::physical_expr::{expressions as phys_expr, PhysicalExpr}; use std::sync::Arc; + use arrow::{ + array::{ArrayRef, BooleanArray, StringArray}, + datatypes::{DataType, Field, Schema}, + record_batch::RecordBatch, + }; + use datafusion::physical_expr::{expressions as phys_expr, PhysicalExpr}; + use crate::string_starts_with::StringStartsWithExpr; #[test] diff --git a/native-engine/datafusion-ext-functions/src/lib.rs b/native-engine/datafusion-ext-functions/src/lib.rs index e420c89c..b3ad80b7 100644 --- a/native-engine/datafusion-ext-functions/src/lib.rs +++ b/native-engine/datafusion-ext-functions/src/lib.rs @@ -12,10 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -use datafusion::common::{DataFusionError, Result}; -use datafusion::logical_expr::ScalarFunctionImplementation; use std::sync::Arc; +use datafusion::{ + common::{DataFusionError, Result}, + logical_expr::ScalarFunctionImplementation, +}; + mod spark_check_overflow; mod spark_get_json_object; mod spark_make_array; diff --git a/native-engine/datafusion-ext-functions/src/spark_check_overflow.rs b/native-engine/datafusion-ext-functions/src/spark_check_overflow.rs index 9d9a28b1..f3c3fb2c 100644 --- a/native-engine/datafusion-ext-functions/src/spark_check_overflow.rs +++ b/native-engine/datafusion-ext-functions/src/spark_check_overflow.rs @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::{cmp::Ordering, sync::Arc}; + use arrow::array::*; -use datafusion::common::Result; -use datafusion::common::ScalarValue; -use datafusion::physical_plan::ColumnarValue; -use std::cmp::Ordering; -use std::sync::Arc; +use datafusion::{ + common::{Result, ScalarValue}, + physical_plan::ColumnarValue, +}; /// implements org.apache.spark.sql.catalyst.expressions.CheckOverflow pub fn spark_check_overflow(args: &[ColumnarValue]) -> Result { @@ -105,8 +106,8 @@ fn change_precision_round_half_up( } } Ordering::Greater => { - // We might be able to multiply i128_val by a power of 10 and not overflow, but if not, - // switch to using a BigDecimal + // We might be able to multiply i128_val by a power of 10 and not overflow, but + // if not, switch to using a BigDecimal let diff = to_scale - scale; // Multiplying i128_val by POW_10(diff) will still keep it below max_long_digits i128_val *= i128::pow(10, diff as u32); @@ -123,12 +124,13 @@ fn change_precision_round_half_up( } #[cfg(test)] mod test { - use crate::spark_check_overflow::spark_check_overflow; - use arrow::array::{ArrayRef, Decimal128Array}; - use datafusion::common::ScalarValue; - use datafusion::logical_expr::ColumnarValue; use std::sync::Arc; + use arrow::array::{ArrayRef, Decimal128Array}; + use datafusion::{common::ScalarValue, logical_expr::ColumnarValue}; + + use crate::spark_check_overflow::spark_check_overflow; + #[test] fn test_check_overflow() { let array = Decimal128Array::from(vec![ @@ -143,8 +145,8 @@ mod test { let result = spark_check_overflow(&vec![ ColumnarValue::Array(Arc::new(array)), - ColumnarValue::Scalar(ScalarValue::Int32(Some(10))), //precision - ColumnarValue::Scalar(ScalarValue::Int32(Some(5))), //scale + ColumnarValue::Scalar(ScalarValue::Int32(Some(10))), // precision + ColumnarValue::Scalar(ScalarValue::Int32(Some(5))), // scale ]) .unwrap() .into_array(5); diff --git a/native-engine/datafusion-ext-functions/src/spark_get_json_object.rs b/native-engine/datafusion-ext-functions/src/spark_get_json_object.rs index af5249ed..fa897195 100644 --- a/native-engine/datafusion-ext-functions/src/spark_get_json_object.rs +++ b/native-engine/datafusion-ext-functions/src/spark_get_json_object.rs @@ -12,16 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use arrow::array::StringArray; -use arrow::array::{new_null_array, Array}; -use arrow::datatypes::DataType; -use datafusion::common::{Result, ScalarValue}; -use datafusion::physical_plan::ColumnarValue; +use std::{any::Any, fmt::Debug, sync::Arc}; + +use arrow::{ + array::{new_null_array, Array, StringArray}, + datatypes::DataType, +}; +use datafusion::{ + common::{Result, ScalarValue}, + physical_plan::ColumnarValue, +}; use datafusion_ext_commons::uda::UserDefinedArray; use itertools::Either; -use std::any::Any; -use std::fmt::Debug; -use std::sync::Arc; /// implement hive/spark's UDFGetJson /// get_json_object(str, path) == get_parsed_json_object(parse_json(str), path) @@ -261,7 +263,8 @@ impl HiveGetJsonObjectMatcher { chars.next(); if chars.peek().cloned() == Some('[') { - return Self::parse(chars); // handle special case like $.aaa.[0].xxx + return Self::parse(chars); // handle special case like + // $.aaa.[0].xxx } let mut child_name = String::new(); loop { @@ -408,13 +411,14 @@ impl HiveGetJsonObjectMatcher { #[cfg(test)] mod test { + use std::sync::Arc; + + use arrow::array::{AsArray, StringArray}; + use datafusion::{common::ScalarValue, logical_expr::ColumnarValue}; + use crate::spark_get_json_object::{ spark_get_parsed_json_object, spark_parse_json, HiveGetJsonObjectEvaluator, }; - use arrow::array::{AsArray, StringArray}; - use datafusion::common::ScalarValue; - use datafusion::logical_expr::ColumnarValue; - use std::sync::Arc; #[test] fn test_hive_demo() { @@ -534,10 +538,10 @@ mod test { let input_array = Arc::new(StringArray::from(vec![input])); let parsed = spark_parse_json(&[ColumnarValue::Array(input_array)]).unwrap(); - //let path = ColumnarValue::Scalar(ScalarValue::from("$.NOT_EXISTED")); - //let r = spark_get_parsed_json_object(&[parsed.clone(), path]).unwrap().into_array(1); - //let v = r.as_string::().iter().next().unwrap(); - //assert_eq!(v, None); + // let path = ColumnarValue::Scalar(ScalarValue::from("$.NOT_EXISTED")); + // let r = spark_get_parsed_json_object(&[parsed.clone(), + // path]).unwrap().into_array(1); let v = r.as_string::().iter(). + // next().unwrap(); assert_eq!(v, None); let path = ColumnarValue::Scalar(ScalarValue::from("$.message.location.county")); let r = spark_get_parsed_json_object(&[parsed.clone(), path]) diff --git a/native-engine/datafusion-ext-functions/src/spark_make_array.rs b/native-engine/datafusion-ext-functions/src/spark_make_array.rs index 2d25450b..e8b57803 100644 --- a/native-engine/datafusion-ext-functions/src/spark_make_array.rs +++ b/native-engine/datafusion-ext-functions/src/spark_make_array.rs @@ -14,13 +14,15 @@ //! Array expressions -use arrow::array::*; -use arrow::datatypes::DataType; -use datafusion::common::Result; -use datafusion::error::DataFusionError; -use datafusion::logical_expr::ColumnarValue; use std::sync::Arc; +use arrow::{array::*, datatypes::DataType}; +use datafusion::{ + common::{Result, ScalarValue}, + error::DataFusionError, + logical_expr::ColumnarValue, +}; + macro_rules! downcast_vec { ($ARGS:expr, $ARRAY_TYPE:ident) => {{ $ARGS @@ -109,9 +111,17 @@ fn array_array(args: &[ArrayRef]) -> Result { DataType::UInt32 => array!(args, UInt32Array, UInt32Builder), DataType::UInt64 => array!(args, UInt64Array, UInt64Builder), data_type => { - return Err(DataFusionError::NotImplemented(format!( - "Array is not implemented for type '{data_type:?}'." 
- ))) + // naive implementation with scalar values + let num_rows = args[0].len(); + let mut output_scalars = Vec::with_capacity(num_rows); + for i in 0..num_rows { + let row_scalars: Vec = args + .iter() + .map(|arg| ScalarValue::try_from_array(arg, i)) + .collect::>()?; + output_scalars.push(ScalarValue::new_list(Some(row_scalars), data_type.clone())); + } + ScalarValue::iter_to_array(output_scalars)? } }; Ok(res) @@ -130,13 +140,16 @@ pub fn array(values: &[ColumnarValue]) -> Result { } #[cfg(test)] mod test { - use crate::spark_make_array::array; - use arrow::array::{ArrayRef, Int32Array, ListArray}; - use arrow::datatypes::{Float32Type, Int32Type}; - use datafusion::common::ScalarValue; - use datafusion::physical_plan::ColumnarValue; use std::sync::Arc; + use arrow::{ + array::{ArrayRef, Int32Array, ListArray}, + datatypes::{Float32Type, Int32Type}, + }; + use datafusion::{common::ScalarValue, physical_plan::ColumnarValue}; + + use crate::spark_make_array::array; + #[test] fn test_make_array_int() { let result = array(&vec![ColumnarValue::Array(Arc::new(Int32Array::from( diff --git a/native-engine/datafusion-ext-functions/src/spark_make_decimal.rs b/native-engine/datafusion-ext-functions/src/spark_make_decimal.rs index 3e079f1f..259729ff 100644 --- a/native-engine/datafusion-ext-functions/src/spark_make_decimal.rs +++ b/native-engine/datafusion-ext-functions/src/spark_make_decimal.rs @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -use arrow::array::*; -use datafusion::common::Result; -use datafusion::common::ScalarValue; -use datafusion::physical_plan::ColumnarValue; use std::sync::Arc; +use arrow::array::*; +use datafusion::{ + common::{Result, ScalarValue}, + physical_plan::ColumnarValue, +}; + /// implements org.apache.spark.sql.catalyst.expressions.MakeDecimal pub fn spark_make_decimal(args: &[ColumnarValue]) -> Result { let precision = match &args[1] { @@ -59,12 +61,13 @@ pub fn spark_make_decimal(args: &[ColumnarValue]) -> Result { } #[cfg(test)] mod test { - use crate::spark_make_decimal::spark_make_decimal; - use arrow::array::{ArrayRef, Decimal128Array, Int64Array}; - use datafusion::common::ScalarValue; - use datafusion::physical_plan::ColumnarValue; use std::sync::Arc; + use arrow::array::{ArrayRef, Decimal128Array, Int64Array}; + use datafusion::{common::ScalarValue, physical_plan::ColumnarValue}; + + use crate::spark_make_decimal::spark_make_decimal; + #[test] fn test_decimal() { let array = Int64Array::from(vec![ @@ -76,8 +79,8 @@ mod test { ]); let result = spark_make_decimal(&vec![ ColumnarValue::Array(Arc::new(array)), - ColumnarValue::Scalar(ScalarValue::Int32(Some(10))), //precision - ColumnarValue::Scalar(ScalarValue::Int32(Some(5))), //scale + ColumnarValue::Scalar(ScalarValue::Int32(Some(10))), // precision + ColumnarValue::Scalar(ScalarValue::Int32(Some(5))), // scale ]) .unwrap() .into_array(5); diff --git a/native-engine/datafusion-ext-functions/src/spark_murmur3_hash.rs b/native-engine/datafusion-ext-functions/src/spark_murmur3_hash.rs index fad67c37..fff454a6 100644 --- a/native-engine/datafusion-ext-functions/src/spark_murmur3_hash.rs +++ b/native-engine/datafusion-ext-functions/src/spark_murmur3_hash.rs @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. 
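The fallback added to array_array above assembles the result row by row: for every input row it pulls one ScalarValue out of each argument column, wraps the row in a list scalar, and converts all list scalars back into one array at the end. The same row-major gathering, sketched with plain vectors instead of arrow arrays (rows_from_columns is illustrative and not part of the patch):

// Sketch only: gather one "list" per row from several equally sized columns.
// The patch does the equivalent with ScalarValue::try_from_array,
// ScalarValue::new_list and ScalarValue::iter_to_array.
fn rows_from_columns<T: Clone>(columns: &[Vec<Option<T>>]) -> Vec<Vec<Option<T>>> {
    let num_rows = columns.first().map(|c| c.len()).unwrap_or(0);
    (0..num_rows)
        .map(|i| columns.iter().map(|col| col[i].clone()).collect())
        .collect()
}

fn main() {
    let columns = vec![vec![Some(1), Some(2)], vec![Some(10), None]];
    // row 0 -> [1, 10], row 1 -> [2, NULL]
    assert_eq!(
        rows_from_columns(&columns),
        vec![vec![Some(1), Some(10)], vec![Some(2), None]]
    );
}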
+use std::sync::Arc; + use arrow::array::*; -use datafusion::common::Result; -use datafusion::physical_plan::ColumnarValue; +use datafusion::{common::Result, physical_plan::ColumnarValue}; use datafusion_ext_commons::spark_hash::create_hashes; -use std::sync::Arc; /// implements org.apache.spark.sql.catalyst.expressions.UnscaledValue pub fn spark_murmur3_hash(args: &[ColumnarValue]) -> Result { @@ -49,10 +49,12 @@ pub fn spark_murmur3_hash(args: &[ColumnarValue]) -> Result { #[cfg(test)] mod test { - use crate::spark_murmur3_hash::spark_murmur3_hash; + use std::sync::Arc; + use arrow::array::{ArrayRef, Int32Array, Int64Array, StringArray}; use datafusion::logical_expr::ColumnarValue; - use std::sync::Arc; + + use crate::spark_murmur3_hash::spark_murmur3_hash; #[test] fn test_murmur3_hash_int64() { diff --git a/native-engine/datafusion-ext-functions/src/spark_null_if_zero.rs b/native-engine/datafusion-ext-functions/src/spark_null_if_zero.rs index a599a121..94dcd0a7 100644 --- a/native-engine/datafusion-ext-functions/src/spark_null_if_zero.rs +++ b/native-engine/datafusion-ext-functions/src/spark_null_if_zero.rs @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -use arrow::array::*; -use arrow::compute::*; -use arrow::datatypes::*; -use datafusion::common::Result; -use datafusion::common::{DataFusionError, ScalarValue}; -use datafusion::physical_plan::ColumnarValue; use std::sync::Arc; +use arrow::{array::*, compute::*, datatypes::*}; +use datafusion::{ + common::{DataFusionError, Result, ScalarValue}, + physical_plan::ColumnarValue, +}; + /// used to avoid DivideByZero error in divide/modulo pub fn spark_null_if_zero(args: &[ColumnarValue]) -> Result { Ok(match &args[0] { @@ -86,12 +86,13 @@ pub fn spark_null_if_zero(args: &[ColumnarValue]) -> Result { #[cfg(test)] mod test { - use crate::spark_null_if_zero::spark_null_if_zero; - use arrow::array::{ArrayRef, Decimal128Array, Float32Array, Int32Array}; - use datafusion::common::ScalarValue; - use datafusion::logical_expr::ColumnarValue; use std::sync::Arc; + use arrow::array::{ArrayRef, Decimal128Array, Float32Array, Int32Array}; + use datafusion::{common::ScalarValue, logical_expr::ColumnarValue}; + + use crate::spark_null_if_zero::spark_null_if_zero; + #[test] fn test_null_if_zero_int() { let result = spark_null_if_zero(&vec![ColumnarValue::Array(Arc::new(Int32Array::from( diff --git a/native-engine/datafusion-ext-functions/src/spark_strings.rs b/native-engine/datafusion-ext-functions/src/spark_strings.rs index 31c7390c..95efb813 100644 --- a/native-engine/datafusion-ext-functions/src/spark_strings.rs +++ b/native-engine/datafusion-ext-functions/src/spark_strings.rs @@ -12,13 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use arrow::array::{Array, ArrayRef, ListArray, ListBuilder, StringArray, StringBuilder}; -use arrow::datatypes::DataType; -use datafusion::common::cast::{as_int32_array, as_list_array, as_string_array}; -use datafusion::common::{DataFusionError, Result, ScalarValue}; -use datafusion::physical_plan::ColumnarValue; use std::sync::Arc; +use arrow::{ + array::{Array, ArrayRef, ListArray, ListBuilder, StringArray, StringBuilder}, + datatypes::DataType, +}; +use datafusion::{ + common::{ + cast::{as_int32_array, as_list_array, as_string_array}, + DataFusionError, Result, ScalarValue, + }, + physical_plan::ColumnarValue, +}; + pub fn string_lower(args: &[ColumnarValue]) -> Result { match &args[0] { ColumnarValue::Array(array) => Ok(ColumnarValue::Array(Arc::new(StringArray::from_iter( @@ -336,15 +343,21 @@ pub fn string_concat_ws(args: &[ColumnarValue]) -> Result { #[cfg(test)] mod test { + use std::sync::Arc; + + use arrow::array::{Int32Array, ListBuilder, StringArray, StringBuilder}; + use datafusion::{ + common::{ + cast::{as_list_array, as_string_array}, + Result, ScalarValue, + }, + physical_plan::ColumnarValue, + }; + use crate::spark_strings::{ string_concat, string_concat_ws, string_lower, string_repeat, string_space, string_split, string_upper, }; - use arrow::array::{Int32Array, ListBuilder, StringArray, StringBuilder}; - use datafusion::common::cast::{as_list_array, as_string_array}; - use datafusion::common::{Result, ScalarValue}; - use datafusion::physical_plan::ColumnarValue; - use std::sync::Arc; #[test] fn test_string_space() -> Result<()> { diff --git a/native-engine/datafusion-ext-functions/src/spark_unscaled_value.rs b/native-engine/datafusion-ext-functions/src/spark_unscaled_value.rs index 1068da76..6c9b3b10 100644 --- a/native-engine/datafusion-ext-functions/src/spark_unscaled_value.rs +++ b/native-engine/datafusion-ext-functions/src/spark_unscaled_value.rs @@ -12,17 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -use arrow::array::*; -use datafusion::common::Result; -use datafusion::common::ScalarValue; -use datafusion::physical_plan::ColumnarValue; use std::sync::Arc; +use arrow::array::*; +use datafusion::{ + common::{Result, ScalarValue}, + physical_plan::ColumnarValue, +}; + /// implements org.apache.spark.sql.catalyst.expressions.UnscaledValue pub fn spark_unscaled_value(args: &[ColumnarValue]) -> Result { Ok(match &args[0] { ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Decimal128(Some(v), _, _) => { + ScalarValue::Decimal128(Some(v), ..) 
=> { ColumnarValue::Scalar(ScalarValue::Int64(Some(*v as i64))) } _ => ColumnarValue::Scalar(ScalarValue::Int64(None)), @@ -40,12 +42,13 @@ pub fn spark_unscaled_value(args: &[ColumnarValue]) -> Result { } #[cfg(test)] mod test { - use crate::spark_unscaled_value::spark_unscaled_value; - use arrow::array::{ArrayRef, Decimal128Array, Int64Array}; - use datafusion::common::ScalarValue; - use datafusion::logical_expr::ColumnarValue; use std::sync::Arc; + use arrow::array::{ArrayRef, Decimal128Array, Int64Array}; + use datafusion::{common::ScalarValue, logical_expr::ColumnarValue}; + + use crate::spark_unscaled_value::spark_unscaled_value; + #[test] fn test_unscaled_value_array() { let result = spark_unscaled_value(&vec![ColumnarValue::Array(Arc::new( diff --git a/native-engine/datafusion-ext-plans/Cargo.toml b/native-engine/datafusion-ext-plans/Cargo.toml index 6c793b14..768f9d3c 100644 --- a/native-engine/datafusion-ext-plans/Cargo.toml +++ b/native-engine/datafusion-ext-plans/Cargo.toml @@ -34,7 +34,4 @@ paste = "1.0.7" slimmer_box = "0.6.5" tempfile = "3" tokio = "1.34" -zstd = "0.12.3" - -[dev-dependencies] -rand = "0.8.5" \ No newline at end of file +zstd = "0.12.3" \ No newline at end of file diff --git a/native-engine/datafusion-ext-plans/src/agg/agg_buf.rs b/native-engine/datafusion-ext-plans/src/agg/agg_buf.rs index 4d3a4b79..e8fef2b1 100644 --- a/native-engine/datafusion-ext-plans/src/agg/agg_buf.rs +++ b/native-engine/datafusion-ext-plans/src/agg/agg_buf.rs @@ -12,18 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::common::slim_bytes::SlimBytes; +use std::{ + any::Any, + collections::HashSet, + io::{Cursor, Read, Write}, + mem::{size_of, size_of_val}, +}; + use arrow::array::Array; use datafusion::common::{Result, ScalarValue}; -use datafusion_ext_commons::io::{ - read_array, read_bytes_slice, read_data_type, read_len, write_array, write_data_type, - write_len, write_u8, +use datafusion_ext_commons::{ + io::{ + read_array, read_bytes_slice, read_data_type, read_len, write_array, write_data_type, + write_len, write_u8, + }, + slim_bytes::SlimBytes, }; use slimmer_box::SlimmerBox; -use std::any::Any; -use std::collections::HashSet; -use std::io::{Cursor, Read, Write}; -use std::mem::{size_of, size_of_val}; #[derive(Eq, PartialEq)] pub struct AggBuf { @@ -176,7 +181,7 @@ pub fn create_agg_buf_from_initial_value( ScalarValue::Boolean(v) => handle_fixed!(v.map(|x| x as u8), 1), ScalarValue::Float32(v) => handle_fixed!(v, 4), ScalarValue::Float64(v) => handle_fixed!(v, 8), - ScalarValue::Decimal128(v, _, _) => handle_fixed!(v, 16), + ScalarValue::Decimal128(v, ..) 
=> handle_fixed!(v, 16), ScalarValue::Int8(v) => handle_fixed!(v, 1), ScalarValue::Int16(v) => handle_fixed!(v, 2), ScalarValue::Int32(v) => handle_fixed!(v, 4), @@ -655,13 +660,14 @@ fn make_dyn_addr(idx: usize) -> u64 { #[cfg(test)] mod test { + use std::{collections::HashSet, io::Cursor}; + + use arrow::datatypes::DataType; + use datafusion::common::{Result, ScalarValue}; + use crate::agg::agg_buf::{ create_agg_buf_from_initial_value, AccumInitialValue, AggDynList, AggDynSet, AggDynStr, }; - use arrow::datatypes::DataType; - use datafusion::common::{Result, ScalarValue}; - use std::collections::HashSet; - use std::io::Cursor; #[test] fn test_dyn_list() { @@ -677,7 +683,11 @@ mod test { dyn_list.load(&mut Cursor::new(&mut buf)).unwrap(); assert_eq!( dyn_list.values, - vec![ScalarValue::from(1i32), ScalarValue::from(2i32), ScalarValue::from(3i32),] + vec![ + ScalarValue::from(1i32), + ScalarValue::from(2i32), + ScalarValue::from(3i32), + ] ); } @@ -698,15 +708,24 @@ mod test { assert_eq!( dyn_set.values, HashSet::from_iter( - vec![ScalarValue::from(1i32), ScalarValue::from(2i32), ScalarValue::from(3i32),] - .into_iter() + vec![ + ScalarValue::from(1i32), + ScalarValue::from(2i32), + ScalarValue::from(3i32), + ] + .into_iter() ) ); } #[test] fn test_agg_buf() { - let data_types = vec![DataType::Null, DataType::Int32, DataType::Int64, DataType::Utf8]; + let data_types = vec![ + DataType::Null, + DataType::Int32, + DataType::Int64, + DataType::Utf8, + ]; let scalars = data_types .iter() .map(|dt: &DataType| Ok(AccumInitialValue::Scalar(dt.clone().try_into()?))) diff --git a/native-engine/datafusion-ext-plans/src/agg/agg_context.rs b/native-engine/datafusion-ext-plans/src/agg/agg_context.rs index 2967726b..cef0b5dd 100644 --- a/native-engine/datafusion-ext-plans/src/agg/agg_context.rs +++ b/native-engine/datafusion-ext-plans/src/agg/agg_context.rs @@ -12,19 +12,35 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::agg::agg_buf::{create_agg_buf_from_initial_value, AccumInitialValue, AggBuf}; -use crate::agg::{Agg, AggExecMode, AggExpr, AggMode, GroupingExpr, AGG_BUF_COLUMN_NAME}; -use crate::common::cached_exprs_evaluator::CachedExprsEvaluator; -use arrow::array::{Array, ArrayRef, BinaryArray, BinaryBuilder}; -use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; -use arrow::record_batch::{RecordBatch, RecordBatchOptions}; -use arrow::row::RowConverter; -use datafusion::common::cast::as_binary_array; -use datafusion::common::Result; -use datafusion::physical_expr::PhysicalExprRef; +use std::{ + fmt::{Debug, Formatter}, + sync::Arc, +}; + +use arrow::{ + array::{Array, ArrayRef, BinaryArray, BinaryBuilder}, + datatypes::{DataType, Field, Fields, Schema, SchemaRef}, + record_batch::{RecordBatch, RecordBatchOptions}, + row::{RowConverter, Rows, SortField}, +}; +use blaze_jni_bridge::{ + conf, + conf::{DoubleConf, IntConf}, +}; +use datafusion::{ + common::{cast::as_binary_array, Result}, + physical_expr::PhysicalExprRef, +}; use once_cell::sync::OnceCell; -use std::fmt::{Debug, Formatter}; -use std::sync::Arc; +use parking_lot::Mutex; + +use crate::{ + agg::{ + agg_buf::{create_agg_buf_from_initial_value, AccumInitialValue, AggBuf}, + Agg, AggExecMode, AggExpr, AggMode, GroupingExpr, AGG_BUF_COLUMN_NAME, + }, + common::cached_exprs_evaluator::CachedExprsEvaluator, +}; pub struct AggContext { pub exec_mode: AggExecMode, @@ -37,11 +53,15 @@ pub struct AggContext { pub grouping_schema: SchemaRef, pub agg_schema: SchemaRef, pub output_schema: SchemaRef, + pub grouping_row_converter: Arc>, pub groupings: Vec, pub aggs: Vec, pub initial_agg_buf: AggBuf, pub initial_input_agg_buf: AggBuf, pub initial_input_buffer_offset: usize, + pub supports_partial_skipping: bool, + pub partial_skipping_ratio: f64, + pub partial_skipping_min_rows: usize, pub agg_expr_evaluator: CachedExprsEvaluator, pub agg_expr_evaluator_output_schema: SchemaRef, @@ -64,6 +84,7 @@ impl AggContext { groupings: Vec, aggs: Vec, initial_input_buffer_offset: usize, + supports_partial_skipping: bool, ) -> Result { let grouping_schema = Arc::new(Schema::new( groupings @@ -77,6 +98,13 @@ impl AggContext { }) .collect::>()?, )); + let grouping_row_converter = Arc::new(Mutex::new(RowConverter::new( + grouping_schema + .fields() + .iter() + .map(|field| SortField::new(field.data_type().clone())) + .collect(), + )?)); // final aggregates may not exist along with partial/partial-merge let need_partial_update = aggs.iter().any(|agg| agg.mode == AggMode::Partial); @@ -111,7 +139,11 @@ impl AggContext { } let agg_schema = Arc::new(Schema::new(agg_fields)); let output_schema = Arc::new(Schema::new( - [grouping_schema.fields().to_vec(), agg_schema.fields().to_vec()].concat(), + [ + grouping_schema.fields().to_vec(), + agg_schema.fields().to_vec(), + ] + .concat(), )); let initial_accums: Box<[AccumInitialValue]> = aggs @@ -126,8 +158,8 @@ impl AggContext { // // Agg [groupings=[], aggs=[ // AggExpr { field_name: "#747", mode: PartialMerge, agg: Count(...) }, - // AggExpr { field_name: "#748", mode: Partial, agg: Count(Column { name: "#640", index: 0 }) } - // ]] + // AggExpr { field_name: "#748", mode: Partial, agg: Count(Column { name: + // "#640", index: 0 }) } ]] // Agg [groupings=[GroupingExpr { field_name: "#640", ...], aggs=[ // AggExpr { field_name: "#747", mode: PartialMerge, agg: Count(...) 
} // ]] @@ -173,6 +205,15 @@ impl AggContext { )); let agg_expr_evaluator = CachedExprsEvaluator::try_new(vec![], agg_exprs_flatten)?; + let (partial_skipping_ratio, partial_skipping_min_rows) = if supports_partial_skipping { + ( + conf::PARTIAL_AGG_SKIPPING_RATIO.value()?, + conf::PARTIAL_AGG_SKIPPING_MIN_ROWS.value()? as usize, + ) + } else { + Default::default() + }; + Ok(Self { exec_mode, need_partial_update, @@ -182,6 +223,7 @@ impl AggContext { need_partial_merge_aggs, output_schema, grouping_schema, + grouping_row_converter, agg_schema, groupings, aggs, @@ -191,11 +233,28 @@ impl AggContext { agg_expr_evaluator, agg_expr_evaluator_output_schema, initial_input_buffer_offset, + supports_partial_skipping, + partial_skipping_ratio, + partial_skipping_min_rows, agg_buf_addr_offsets: agg_buf_addr_offsets.into(), agg_buf_addr_counts: agg_buf_addr_counts.into(), }) } + pub fn create_grouping_rows(&self, input_batch: &RecordBatch) -> Result { + let grouping_arrays: Vec = self + .groupings + .iter() + .map(|grouping| grouping.expr.evaluate(&input_batch)) + .map(|r| r.map(|columnar| columnar.into_array(input_batch.num_rows()))) + .collect::>() + .map_err(|err| err.context("agg: evaluating grouping arrays error"))?; + Ok(self + .grouping_row_converter + .lock() + .convert_columns(&grouping_arrays)?) + } + pub fn create_input_arrays(&self, input_batch: &RecordBatch) -> Result>> { if !self.need_partial_update { return Ok(vec![]); @@ -262,10 +321,10 @@ impl AggContext { pub fn convert_records_to_batch( &self, - grouping_row_converter: &mut RowConverter, records: Vec<(impl AsRef<[u8]>, AggBuf)>, ) -> Result { let row_count = records.len(); + let grouping_row_converter = self.grouping_row_converter.lock(); let grouping_row_parser = grouping_row_converter.parser(); let grouping_columns = grouping_row_converter.convert_rows( records diff --git a/native-engine/datafusion-ext-plans/src/agg/agg_tables.rs b/native-engine/datafusion-ext-plans/src/agg/agg_tables.rs index 4e4fc1db..7e6d3fce 100644 --- a/native-engine/datafusion-ext-plans/src/agg/agg_tables.rs +++ b/native-engine/datafusion-ext-plans/src/agg/agg_tables.rs @@ -12,35 +12,42 @@ // See the License for the specific language governing permissions and // limitations under the License. 
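The PARTIAL_AGG_SKIPPING_RATIO and PARTIAL_AGG_SKIPPING_MIN_ROWS values read into AggContext above feed the skipping check that the in-memory aggregation table performs later in this patch: once enough input rows have been seen, the number of distinct groups is compared against the number of input rows, and partial aggregation is abandoned when the ratio is too close to 1. A minimal sketch of that decision in isolation (the function name and thresholds are illustrative, not the shipped defaults):

// Sketch only: decide whether partial aggregation is still worth doing.
// High cardinality (groups ~ input rows) means hashing buys little, so skip it.
fn should_skip_partial_agg(
    num_groups: usize,
    num_input_rows: usize,
    skipping_ratio: f64,
    min_rows: usize,
) -> bool {
    if num_input_rows < min_rows {
        return false; // not enough data to judge yet
    }
    let cardinality_ratio = num_groups as f64 / num_input_rows as f64;
    cardinality_ratio > skipping_ratio
}

fn main() {
    // 9,000 groups out of 10,000 rows with a 0.8 threshold: aggregation degenerates.
    assert!(should_skip_partial_agg(9_000, 10_000, 0.8, 1_000));
    // 500 groups out of 10,000 rows: keep aggregating.
    assert!(!should_skip_partial_agg(500, 10_000, 0.8, 1_000));
}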
-use ahash::RandomState; -use std::hash::{BuildHasher, Hash, Hasher}; -use std::io::{BufReader, Read, Write}; -use std::mem::{size_of, ManuallyDrop}; -use std::sync::{Arc, Weak}; +use std::{ + hash::{BuildHasher, Hash, Hasher}, + io::{BufReader, Read, Write}, + mem::{size_of, ManuallyDrop}, + sync::{Arc, Weak}, +}; -use arrow::row::{RowConverter, Rows}; +use ahash::RandomState; +use arrow::{record_batch::RecordBatch, row::Rows}; use async_trait::async_trait; -use datafusion::common::Result; -use datafusion::error::DataFusionError; -use datafusion::execution::context::TaskContext; - -use datafusion::physical_plan::metrics::BaselineMetrics; +use datafusion::{ + common::Result, error::DataFusionError, execution::context::TaskContext, + physical_plan::metrics::BaselineMetrics, +}; +use datafusion_ext_commons::{ + bytes_arena::BytesArena, + io::{read_bytes_slice, read_len, write_len}, + loser_tree::LoserTree, + rdxsort, + slim_bytes::SlimBytes, +}; use futures::lock::Mutex; -use hashbrown::hash_map::{Entry, RawEntryMut}; -use hashbrown::HashMap; +use hashbrown::{ + hash_map::{Entry, RawEntryMut}, + HashMap, +}; use lz4_flex::frame::FrameDecoder; -use datafusion_ext_commons::io::{read_bytes_slice, read_len, write_len}; -use datafusion_ext_commons::loser_tree::LoserTree; - -use crate::agg::agg_buf::AggBuf; -use crate::agg::agg_context::AggContext; -use crate::common::bytes_arena::BytesArena; -use crate::common::memory_manager::{MemConsumer, MemConsumerInfo, MemManager}; -use crate::common::onheap_spill::{try_new_spill, Spill}; -use crate::common::output::WrappedRecordBatchSender; -use crate::common::rdxsort; -use crate::common::slim_bytes::SlimBytes; +use crate::{ + agg::{agg_buf::AggBuf, agg_context::AggContext}, + common::output::WrappedRecordBatchSender, + memmgr::{ + onheap_spill::{try_new_spill, Spill}, + MemConsumer, MemConsumerInfo, MemManager, + }, +}; // reserve memory for each spill // estimated size: bufread=64KB + lz4dec.src=64KB + lz4dec.dest=64KB @@ -74,7 +81,8 @@ impl AggTables { Self { name: format!("AggTable[partition={}]", partition_id), mem_consumer_info: None, - in_mem: Mutex::new(InMemTable::new(true)), // only the first im-mem table uses hash + // only the first im-mem table uses hash + in_mem: Mutex::new(InMemTable::new(agg_ctx.clone(), InMemMode::Hash)), spills: Mutex::default(), agg_ctx, context, @@ -82,17 +90,33 @@ impl AggTables { } } - pub async fn update_entries( - &self, - key_rows: Rows, - fn_entries: impl Fn(&mut [AggBuf]) -> Result, - ) -> Result<()> { + pub async fn process_input_batch(&self, input_batch: RecordBatch) -> Result<()> { let mut in_mem = self.in_mem.lock().await; - in_mem.update_entries(&self.agg_ctx, key_rows, fn_entries)?; + // compute grouping rows + let grouping_rows = self.agg_ctx.create_grouping_rows(&input_batch)?; + + // compute input arrays + let input_arrays = self.agg_ctx.create_input_arrays(&input_batch)?; + let agg_buf_array = self.agg_ctx.get_input_agg_buf_array(&input_batch)?; + in_mem.update_entries(grouping_rows, |agg_bufs| { + let mut mem_diff = 0; + mem_diff += self + .agg_ctx + .partial_batch_update_input(agg_bufs, &input_arrays)?; + mem_diff += self + .agg_ctx + .partial_batch_merge_input(agg_bufs, agg_buf_array)?; + Ok(mem_diff) + })?; let mem_used = in_mem.mem_used(); drop(in_mem); - self.update_mem_used(mem_used).await?; + + // if triggered partial skipping, no need to update memory usage and try to + // spill + if self.mode().await != InMemMode::PartialSkipped { + self.update_mem_used(mem_used).await?; + } Ok(()) } @@ -100,16 
+124,53 @@ impl AggTables { !self.spills.lock().await.is_empty() } + pub async fn mode(&self) -> InMemMode { + self.in_mem.lock().await.mode + } + + pub async fn renew_in_mem_table(&self, mode: InMemMode) -> InMemTable { + let mut old = self.in_mem.lock().await; + let new = InMemTable::new(self.agg_ctx.clone(), mode); + std::mem::replace(&mut *old, new) + } + + pub async fn process_partial_skipped( + &self, + input_batch: RecordBatch, + baseline_metrics: BaselineMetrics, + sender: Arc, + ) -> Result<()> { + self.set_spillable(false); + let mut timer = baseline_metrics.elapsed_compute().timer(); + + let old_in_mem = self.renew_in_mem_table(InMemMode::PartialSkipped).await; + assert_eq!(old_in_mem.num_records(), 0); // old table must be cleared + + self.process_input_batch(input_batch).await?; + let in_mem = self.renew_in_mem_table(InMemMode::PartialSkipped).await; + let records = in_mem + .unsorted_keys + .iter() + .flat_map(|rows| rows.iter()) + .zip(in_mem.unsorted_values.into_iter()) + .collect::>(); + let batch = self.agg_ctx.convert_records_to_batch(records)?; + + baseline_metrics.record_output(batch.num_rows()); + sender.send(Ok(batch), Some(&mut timer)).await; + self.update_mem_used(0).await?; + return Ok(()); + } + pub async fn output( &self, - mut grouping_row_converter: RowConverter, baseline_metrics: BaselineMetrics, sender: Arc, ) -> Result<()> { self.set_spillable(false); let mut timer = baseline_metrics.elapsed_compute().timer(); - let in_mem = std::mem::replace(&mut *self.in_mem.lock().await, InMemTable::new(true)); + let in_mem = self.renew_in_mem_table(InMemMode::PartialSkipped).await; let spills = std::mem::take(&mut *self.spills.lock().await); let batch_size = self.context.session_config().batch_size(); @@ -131,9 +192,7 @@ impl AggTables { let chunk = records.split_off(records.len().saturating_sub(batch_size)); records.shrink_to_fit(); - let batch = self - .agg_ctx - .convert_records_to_batch(&mut grouping_row_converter, chunk)?; + let batch = self.agg_ctx.convert_records_to_batch(chunk)?; let batch_mem_size = batch.get_array_memory_size(); baseline_metrics.record_output(batch.num_rows()); @@ -167,10 +226,9 @@ impl AggTables { macro_rules! 
flush_staging { () => {{ - let batch = self.agg_ctx.convert_records_to_batch( - &mut grouping_row_converter, - std::mem::take(&mut staging_records), - )?; + let batch = self + .agg_ctx + .convert_records_to_batch(std::mem::take(&mut staging_records))?; baseline_metrics.record_output(batch.num_rows()); sender.send(Ok(batch), Some(&mut timer)).await; }}; @@ -259,11 +317,16 @@ impl MemConsumer for AggTables { let mut in_mem = self.in_mem.lock().await; let mut spills = self.spills.lock().await; - spills.extend(std::mem::replace(&mut *in_mem, InMemTable::new(false)).try_into_spill()?); - drop(spills); - drop(in_mem); - - self.update_mem_used(0).await?; + // do not spill anything if triggered partial skipping + // regardless minRows configuration + in_mem.check_trigger_partial_skipping(); + if in_mem.mode != InMemMode::PartialSkipped { + let replaced_in_mem = InMemTable::new(self.agg_ctx.clone(), InMemMode::Merge); + spills.extend(std::mem::replace(&mut *in_mem, replaced_in_mem).try_into_spill()?); + drop(spills); + drop(in_mem); + self.update_mem_used(0).await?; + } Ok(()) } } @@ -274,15 +337,24 @@ impl Drop for AggTables { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum InMemMode { + Hash, + Merge, + PartialSkipped, +} + /// Unordered in-mem hash table which can be updated pub struct InMemTable { + agg_ctx: Arc, map_keys: Box, map: HashMap, unsorted_keys: Vec, unsorted_values: Vec, unsorted_keys_mem_used: usize, agg_buf_mem_used: usize, - pub is_hash: bool, + num_input_records: usize, + mode: InMemMode, } // a hasher for hashing addrs got from map_keys @@ -324,20 +396,22 @@ impl BuildHasher for MapKeyHashBuilder { unsafe impl Send for MapKeyHashBuilder {} impl InMemTable { - fn new(is_hash: bool) -> Self { + fn new(agg_ctx: Arc, mode: InMemMode) -> Self { let map_keys: Box = Box::default(); let map_key_hash_builder = MapKeyHashBuilder(unsafe { // safety: hash builder's lifetime is shorter than map_keys std::mem::transmute::<_, &'static BytesArena>(map_keys.as_ref()) }); Self { + agg_ctx, map_keys, map: HashMap::with_hasher(map_key_hash_builder), unsorted_keys: vec![], unsorted_values: vec![], unsorted_keys_mem_used: 0, agg_buf_mem_used: 0, - is_hash, + num_input_records: 0, + mode, } } @@ -359,20 +433,24 @@ impl InMemTable { pub fn update_entries( &mut self, - agg_ctx: &Arc, key_rows: Rows, fn_entries: impl Fn(&mut [AggBuf]) -> Result, ) -> Result<()> { - if self.is_hash { - self.update_hash_entries(agg_ctx, key_rows, fn_entries) + let num_input_records = key_rows.num_rows(); + let mem_diff = if self.mode == InMemMode::Hash { + self.update_hash_entries(key_rows, fn_entries)? } else { - self.update_unsorted_entries(agg_ctx, key_rows, fn_entries) + self.update_unsorted_entries(key_rows, fn_entries)? + }; + self.num_input_records += num_input_records; + if self.num_input_records >= self.agg_ctx.partial_skipping_min_rows { + self.check_trigger_partial_skipping(); } + Ok(mem_diff) } fn update_hash_entries( &mut self, - agg_ctx: &Arc, key_rows: Rows, fn_entries: impl Fn(&mut [AggBuf]) -> Result, ) -> Result<()> { @@ -396,7 +474,7 @@ impl InMemTable { } RawEntryMut::Vacant(view) => { let new_key_addr = self.map_keys.add(row.as_ref()); - let new_entry = agg_ctx.initial_agg_buf.clone(); + let new_entry = self.agg_ctx.initial_agg_buf.clone(); // safety: agg_buf lives longer than this function call. // items in agg_bufs are later moved into ManuallyDrop to avoid double drop. 
agg_bufs.push(unsafe { std::ptr::read(&new_entry as *const AggBuf) }); @@ -414,78 +492,98 @@ impl InMemTable { fn update_unsorted_entries( &mut self, - agg_ctx: &Arc, key_rows: Rows, fn_entries: impl Fn(&mut [AggBuf]) -> Result, ) -> Result<()> { let beg = self.unsorted_values.len(); let len = key_rows.num_rows(); self.unsorted_values - .extend((0..len).map(|_| agg_ctx.initial_agg_buf.clone())); + .extend((0..len).map(|_| self.agg_ctx.initial_agg_buf.clone())); self.agg_buf_mem_used += fn_entries(&mut self.unsorted_values[beg..][..len])?; self.unsorted_keys_mem_used += key_rows.size(); self.unsorted_keys.push(key_rows); Ok(()) } + fn check_trigger_partial_skipping(&mut self) { + if self.agg_ctx.supports_partial_skipping && self.mode != InMemMode::PartialSkipped { + let num_input_records = self.num_input_records; + let num_records = self.num_records(); + let cardinality_ratio = num_records as f64 / num_input_records as f64; + if cardinality_ratio > self.agg_ctx.partial_skipping_ratio { + log::warn!( + "Agg: cardinality ratio = {cardinality_ratio}, will trigger partial skipping" + ); + self.mode = InMemMode::PartialSkipped; + } + } + } + fn try_into_spill(self) -> Result>> { if self.map.is_empty() && self.unsorted_values.is_empty() { return Ok(None); } + fn write_spill( + bucket_counts: &[usize], + bucketed_records: &mut [(u16, impl AsRef<[u8]>, AggBuf)], + ) -> Result> { + let spill = try_new_spill()?; + let mut writer = lz4_flex::frame::FrameEncoder::new(spill.get_buf_writer()); + let mut beg = 0; + + for i in 0..65536 { + if bucket_counts[i] > 0 { + // write bucket id and number of records in this bucket + write_len(i, &mut writer)?; + write_len(bucket_counts[i], &mut writer)?; + + // write records in this bucket + for (_, key, value) in &mut bucketed_records[beg..][..bucket_counts[i]] { + // write key + let key = key.as_ref(); + write_len(key.len(), &mut writer)?; + writer.write_all(key)?; + + // write value + value.save(&mut writer)?; + } + } + beg += bucket_counts[i]; + } + write_len(65536, &mut writer)?; // EOF + write_len(0, &mut writer)?; + writer + .finish() + .map_err(|err| DataFusionError::Execution(format!("{}", err)))?; + spill.complete()?; + Ok(spill) + } + // sort all records using radix sort on hashcodes of keys - let mut sorted: Vec<(u16, &[u8], AggBuf)> = if self.is_hash { - self.map + let spill = if self.mode == InMemMode::Hash { + let mut sorted = self + .map .into_iter() .map(|(key_addr, value)| { let key = self.map_keys.get(key_addr); let key_hash = RANDOM_STATE.hash_one(key) as u16; (key_hash, key, value) }) - .collect() + .collect::>(); + let bucket_counts = rdxsort::radix_sort_u16_by(&mut sorted, |(h, ..)| *h); + write_spill(&bucket_counts, &mut sorted)? } else { - self.unsorted_values - .into_iter() - .zip(self.unsorted_keys.iter().flat_map(|rows| { - rows.iter().map(|row| { - // safety - row bytes has same lifetime with self.unsorted_rows - unsafe { std::mem::transmute::<_, &'static [u8]>(row.as_ref()) } - }) - })) - .map(|(value, key)| (RANDOM_STATE.hash_one(key) as u16, key, value)) - .collect() + let mut sorted = self + .unsorted_keys + .iter() + .flat_map(|rows| rows.iter()) + .zip(self.unsorted_values.into_iter()) + .map(|(key, value)| (RANDOM_STATE.hash_one(key.as_ref()) as u16, key, value)) + .collect::>(); + let bucket_counts = rdxsort::radix_sort_u16_by(&mut sorted, |(h, ..)| *h); + write_spill(&bucket_counts, &mut sorted)? 
}; - let counts = rdxsort::radix_sort_u16_by(&mut sorted, |(h, _, _)| *h); - - let spill = try_new_spill()?; - let mut writer = lz4_flex::frame::FrameEncoder::new(spill.get_buf_writer()); - let mut beg = 0; - - for i in 0..65536 { - if counts[i] > 0 { - // write bucket id and number of records in this bucket - write_len(i, &mut writer)?; - write_len(counts[i], &mut writer)?; - - // write records in this bucket - for (_, key, value) in &mut sorted[beg..][..counts[i]] { - // write key - write_len(key.len(), &mut writer)?; - writer.write_all(key)?; - - // write value - value.save(&mut writer)?; - } - beg += counts[i]; - } - } - write_len(65536, &mut writer)?; // EOF - write_len(0, &mut writer)?; - - writer - .finish() - .map_err(|err| DataFusionError::Execution(format!("{}", err)))?; - spill.complete()?; Ok(Some(spill)) } } diff --git a/native-engine/datafusion-ext-plans/src/agg/avg.rs b/native-engine/datafusion-ext-plans/src/agg/avg.rs index f63c08d5..fb541fba 100644 --- a/native-engine/datafusion-ext-plans/src/agg/avg.rs +++ b/native-engine/datafusion-ext-plans/src/agg/avg.rs @@ -12,19 +12,28 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::agg::agg_buf::{AccumInitialValue, AggBuf}; -use crate::agg::count::AggCount; -use crate::agg::sum::AggSum; -use crate::agg::Agg; -use arrow::array::*; -use arrow::datatypes::*; -use datafusion::common::cast::{as_decimal128_array, as_int64_array}; -use datafusion::common::{Result, ScalarValue}; -use datafusion::error::DataFusionError; -use datafusion::physical_expr::PhysicalExpr; -use std::any::Any; -use std::fmt::{Debug, Formatter}; -use std::sync::Arc; +use std::{ + any::Any, + fmt::{Debug, Formatter}, + sync::Arc, +}; + +use arrow::{array::*, datatypes::*}; +use datafusion::{ + common::{ + cast::{as_decimal128_array, as_int64_array}, + Result, ScalarValue, + }, + error::DataFusionError, + physical_expr::PhysicalExpr, +}; + +use crate::agg::{ + agg_buf::{AccumInitialValue, AggBuf}, + count::AggCount, + sum::AggSum, + Agg, +}; pub struct AggAvg { child: Arc, @@ -212,7 +221,7 @@ impl Agg for AggAvg { fn get_final_merger(dt: &DataType) -> Result ScalarValue> { macro_rules! get_fn { - ($ty:ident, f64) => {{ + ($ty:ident,f64) => {{ Ok(|sum: ScalarValue, count: i64| { let avg = match sum { ScalarValue::$ty(sum, ..) => ScalarValue::Float64(if !count.is_zero() { diff --git a/native-engine/datafusion-ext-plans/src/agg/collect_list.rs b/native-engine/datafusion-ext-plans/src/agg/collect_list.rs index d97a0014..42d8ab9d 100644 --- a/native-engine/datafusion-ext-plans/src/agg/collect_list.rs +++ b/native-engine/datafusion-ext-plans/src/agg/collect_list.rs @@ -12,15 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. 
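The spill path rewritten in agg_tables.rs above tags every record with a 16-bit hash of its key, radix-sorts records by that bucket id, and writes each non-empty bucket together with its length so spilled runs can later be merged bucket by bucket. The bucketing idea, sketched with a plain std sort and without the on-disk framing (bucket_by_u16_hash is illustrative, not a function in this patch):

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Sketch only: group records into 65536 buckets by a 16-bit hash of the key,
// returning per-bucket counts plus the records ordered by bucket id.
// The patch does this in place with rdxsort::radix_sort_u16_by.
fn bucket_by_u16_hash<V>(records: Vec<(Vec<u8>, V)>) -> (Vec<usize>, Vec<(u16, Vec<u8>, V)>) {
    let mut tagged: Vec<(u16, Vec<u8>, V)> = records
        .into_iter()
        .map(|(key, value)| {
            let mut hasher = DefaultHasher::new();
            key.hash(&mut hasher);
            (hasher.finish() as u16, key, value)
        })
        .collect();

    let mut counts = vec![0usize; 65536];
    for (bucket, ..) in &tagged {
        counts[*bucket as usize] += 1;
    }
    tagged.sort_by_key(|(bucket, ..)| *bucket); // stand-in for the radix sort
    (counts, tagged)
}

fn main() {
    let (counts, sorted) = bucket_by_u16_hash(vec![
        (b"apple".to_vec(), 1),
        (b"banana".to_vec(), 2),
    ]);
    assert_eq!(counts.iter().sum::<usize>(), sorted.len());
}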
-use crate::agg::agg_buf::{AccumInitialValue, AggBuf, AggDynList}; -use crate::agg::Agg; -use arrow::array::*; -use arrow::datatypes::*; -use datafusion::common::{Result, ScalarValue}; -use datafusion::physical_expr::PhysicalExpr; -use std::any::Any; -use std::fmt::{Debug, Formatter}; -use std::sync::Arc; +use std::{ + any::Any, + fmt::{Debug, Formatter}, + sync::Arc, +}; + +use arrow::{array::*, datatypes::*}; +use datafusion::{ + common::{Result, ScalarValue}, + physical_expr::PhysicalExpr, +}; + +use crate::agg::{ + agg_buf::{AccumInitialValue, AggBuf, AggDynList}, + Agg, +}; pub struct AggCollectList { child: Arc, diff --git a/native-engine/datafusion-ext-plans/src/agg/collect_set.rs b/native-engine/datafusion-ext-plans/src/agg/collect_set.rs index 3941a4ef..c1c274a3 100644 --- a/native-engine/datafusion-ext-plans/src/agg/collect_set.rs +++ b/native-engine/datafusion-ext-plans/src/agg/collect_set.rs @@ -12,15 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::agg::agg_buf::{AccumInitialValue, AggBuf, AggDynSet}; -use crate::agg::Agg; -use arrow::array::*; -use arrow::datatypes::*; -use datafusion::common::{Result, ScalarValue}; -use datafusion::physical_expr::PhysicalExpr; -use std::any::Any; -use std::fmt::{Debug, Formatter}; -use std::sync::Arc; +use std::{ + any::Any, + fmt::{Debug, Formatter}, + sync::Arc, +}; + +use arrow::{array::*, datatypes::*}; +use datafusion::{ + common::{Result, ScalarValue}, + physical_expr::PhysicalExpr, +}; + +use crate::agg::{ + agg_buf::{AccumInitialValue, AggBuf, AggDynSet}, + Agg, +}; pub struct AggCollectSet { child: Arc, diff --git a/native-engine/datafusion-ext-plans/src/agg/count.rs b/native-engine/datafusion-ext-plans/src/agg/count.rs index 80a6e1f6..54b016db 100644 --- a/native-engine/datafusion-ext-plans/src/agg/count.rs +++ b/native-engine/datafusion-ext-plans/src/agg/count.rs @@ -12,15 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::agg::agg_buf::{AccumInitialValue, AggBuf}; -use crate::agg::Agg; -use arrow::array::*; -use arrow::datatypes::*; -use datafusion::common::{Result, ScalarValue}; -use datafusion::physical_expr::PhysicalExpr; -use std::any::Any; -use std::fmt::{Debug, Formatter}; -use std::sync::Arc; +use std::{ + any::Any, + fmt::{Debug, Formatter}, + sync::Arc, +}; + +use arrow::{array::*, datatypes::*}; +use datafusion::{ + common::{Result, ScalarValue}, + physical_expr::PhysicalExpr, +}; + +use crate::agg::{ + agg_buf::{AccumInitialValue, AggBuf}, + Agg, +}; pub struct AggCount { child: Arc, diff --git a/native-engine/datafusion-ext-plans/src/agg/first.rs b/native-engine/datafusion-ext-plans/src/agg/first.rs index 588ef7d1..ccb5b668 100644 --- a/native-engine/datafusion-ext-plans/src/agg/first.rs +++ b/native-engine/datafusion-ext-plans/src/agg/first.rs @@ -12,16 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::agg::agg_buf::{AccumInitialValue, AggBuf, AggDynBinary, AggDynScalar, AggDynStr}; -use crate::agg::Agg; -use arrow::array::*; -use arrow::datatypes::*; -use datafusion::common::{Result, ScalarValue}; -use datafusion::physical_expr::PhysicalExpr; +use std::{ + any::Any, + fmt::{Debug, Formatter}, + sync::Arc, +}; + +use arrow::{array::*, datatypes::*}; +use datafusion::{ + common::{Result, ScalarValue}, + physical_expr::PhysicalExpr, +}; use paste::paste; -use std::any::Any; -use std::fmt::{Debug, Formatter}; -use std::sync::Arc; + +use crate::agg::{ + agg_buf::{AccumInitialValue, AggBuf, AggDynBinary, AggDynScalar, AggDynStr}, + Agg, +}; pub struct AggFirst { child: Arc, @@ -168,7 +175,7 @@ fn get_partial_updater(dt: &DataType) -> Result fn_fixed!(TimestampMillisecond), DataType::Timestamp(TimeUnit::Microsecond, _) => fn_fixed!(TimestampMicrosecond), DataType::Timestamp(TimeUnit::Nanosecond, _) => fn_fixed!(TimestampNanosecond), - DataType::Decimal128(_, _) => fn_fixed!(Decimal128), + DataType::Decimal128(..) => fn_fixed!(Decimal128), DataType::Utf8 => Ok( |agg_buf: &mut AggBuf, addrs: &[u64], v: &ArrayRef, i: usize| { let w = AggDynStr::value_mut(agg_buf.dyn_value_mut(addrs[0])); @@ -246,7 +253,7 @@ fn get_partial_buf_merger(dt: &DataType) -> Result fn_fixed!(TimestampMillisecond), DataType::Timestamp(TimeUnit::Microsecond, _) => fn_fixed!(TimestampMicrosecond), DataType::Timestamp(TimeUnit::Nanosecond, _) => fn_fixed!(TimestampNanosecond), - DataType::Decimal128(_, _) => fn_fixed!(Decimal128), + DataType::Decimal128(..) => fn_fixed!(Decimal128), DataType::Utf8 => Ok(|agg_buf1, agg_buf2, addrs| { if is_touched(agg_buf2, addrs) { let w = AggDynStr::value_mut(agg_buf1.dyn_value_mut(addrs[0])); diff --git a/native-engine/datafusion-ext-plans/src/agg/first_ignores_null.rs b/native-engine/datafusion-ext-plans/src/agg/first_ignores_null.rs index ee0eb981..2b3552aa 100644 --- a/native-engine/datafusion-ext-plans/src/agg/first_ignores_null.rs +++ b/native-engine/datafusion-ext-plans/src/agg/first_ignores_null.rs @@ -12,16 +12,23 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::agg::agg_buf::{AccumInitialValue, AggBuf, AggDynBinary, AggDynScalar, AggDynStr}; -use crate::agg::Agg; -use arrow::array::*; -use arrow::datatypes::*; -use datafusion::common::{Result, ScalarValue}; -use datafusion::physical_expr::PhysicalExpr; +use std::{ + any::Any, + fmt::{Debug, Formatter}, + sync::Arc, +}; + +use arrow::{array::*, datatypes::*}; +use datafusion::{ + common::{Result, ScalarValue}, + physical_expr::PhysicalExpr, +}; use paste::paste; -use std::any::Any; -use std::fmt::{Debug, Formatter}; -use std::sync::Arc; + +use crate::agg::{ + agg_buf::{AccumInitialValue, AggBuf, AggDynBinary, AggDynScalar, AggDynStr}, + Agg, +}; pub struct AggFirstIgnoresNull { child: Arc, @@ -33,7 +40,9 @@ pub struct AggFirstIgnoresNull { impl AggFirstIgnoresNull { pub fn try_new(child: Arc, data_type: DataType) -> Result { - let accums_initial = vec![AccumInitialValue::Scalar(ScalarValue::try_from(&data_type)?)]; + let accums_initial = vec![AccumInitialValue::Scalar(ScalarValue::try_from( + &data_type, + )?)]; let partial_updater = get_partial_updater(&data_type)?; let partial_buf_merger = get_partial_buf_merger(&data_type)?; Ok(Self { @@ -156,7 +165,7 @@ fn get_partial_updater(dt: &DataType) -> Result fn_fixed!(TimestampMillisecond), DataType::Timestamp(TimeUnit::Microsecond, _) => fn_fixed!(TimestampMicrosecond), DataType::Timestamp(TimeUnit::Nanosecond, _) => fn_fixed!(TimestampNanosecond), - DataType::Decimal128(_, _) => fn_fixed!(Decimal128), + DataType::Decimal128(..) => fn_fixed!(Decimal128), DataType::Utf8 => Ok(|agg_buf: &mut AggBuf, addr: u64, v: &ArrayRef, i: usize| { let w = AggDynStr::value_mut(agg_buf.dyn_value_mut(addr)); if w.is_none() && v.is_valid(i) { @@ -218,7 +227,7 @@ fn get_partial_buf_merger(dt: &DataType) -> Result fn_fixed!(TimestampMillisecond), DataType::Timestamp(TimeUnit::Microsecond, _) => fn_fixed!(TimestampMicrosecond), DataType::Timestamp(TimeUnit::Nanosecond, _) => fn_fixed!(TimestampNanosecond), - DataType::Decimal128(_, _) => fn_fixed!(Decimal128), + DataType::Decimal128(..) => fn_fixed!(Decimal128), DataType::Utf8 => Ok(|agg_buf1, agg_buf2, addr| { let w = AggDynStr::value_mut(agg_buf1.dyn_value_mut(addr)); if w.is_none() { diff --git a/native-engine/datafusion-ext-plans/src/agg/maxmin.rs b/native-engine/datafusion-ext-plans/src/agg/maxmin.rs index 65f1a775..d87b2780 100644 --- a/native-engine/datafusion-ext-plans/src/agg/maxmin.rs +++ b/native-engine/datafusion-ext-plans/src/agg/maxmin.rs @@ -12,19 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::agg::agg_buf::{AccumInitialValue, AggBuf, AggDynStr}; -use crate::agg::Agg; -use arrow::array::*; -use arrow::datatypes::*; -use datafusion::common::{Result, ScalarValue}; -use datafusion::error::DataFusionError; -use datafusion::physical_expr::PhysicalExpr; +use std::{ + any::Any, + cmp::Ordering, + fmt::{Debug, Formatter}, + marker::PhantomData, + sync::Arc, +}; + +use arrow::{array::*, datatypes::*}; +use datafusion::{ + common::{Result, ScalarValue}, + error::DataFusionError, + physical_expr::PhysicalExpr, +}; use paste::paste; -use std::any::Any; -use std::cmp::Ordering; -use std::fmt::{Debug, Formatter}; -use std::marker::PhantomData; -use std::sync::Arc; + +use crate::agg::{ + agg_buf::{AccumInitialValue, AggBuf, AggDynStr}, + Agg, +}; pub type AggMax = AggMaxMin; pub type AggMin = AggMaxMin; @@ -41,7 +48,9 @@ pub struct AggMaxMin { impl AggMaxMin
{
     pub fn try_new(child: Arc, data_type: DataType) -> Result {
-        let accums_initial = vec![AccumInitialValue::Scalar(ScalarValue::try_from(&data_type)?)];
+        let accums_initial = vec![AccumInitialValue::Scalar(ScalarValue::try_from(
+            &data_type,
+        )?)];
         let partial_updater = get_partial_updater::<P>(&data_type)?;
         let partial_batch_updater = get_partial_batch_updater::<P>(&data_type)?;
         let partial_buf_merger = get_partial_buf_merger::<P>(&data_type)?;
@@ -177,7 +186,7 @@ impl Agg for AggMaxMin
{ DataType::Timestamp(TimeUnit::Nanosecond, _) => { handle_fixed!(TimestampNanosecond, P::maxmin) } - DataType::Decimal128(_, _) => handle_fixed!(Decimal128, P::maxmin), + DataType::Decimal128(..) => handle_fixed!(Decimal128, P::maxmin), DataType::Utf8 => { let value = values[0].as_any().downcast_ref::().unwrap(); if let Some(max) = P::maxmin_string(value) { @@ -295,7 +304,7 @@ fn get_partial_updater( DataType::Timestamp(TimeUnit::Millisecond, _) => fn_fixed!(TimestampMillisecond), DataType::Timestamp(TimeUnit::Microsecond, _) => fn_fixed!(TimestampMicrosecond), DataType::Timestamp(TimeUnit::Nanosecond, _) => fn_fixed!(TimestampNanosecond), - DataType::Decimal128(_, _) => fn_fixed!(Decimal128), + DataType::Decimal128(..) => fn_fixed!(Decimal128), DataType::Utf8 => Ok(|agg_buf: &mut AggBuf, addr: u64, v: &ArrayRef, i: usize| { let value = v.as_any().downcast_ref::().unwrap(); if value.is_valid(i) { @@ -353,7 +362,7 @@ fn get_partial_batch_updater( DataType::Timestamp(TimeUnit::Millisecond, _) => fn_fixed!(TimestampMillisecond), DataType::Timestamp(TimeUnit::Microsecond, _) => fn_fixed!(TimestampMicrosecond), DataType::Timestamp(TimeUnit::Nanosecond, _) => fn_fixed!(TimestampNanosecond), - DataType::Decimal128(_, _) => fn_fixed!(Decimal128), + DataType::Decimal128(..) => fn_fixed!(Decimal128), DataType::Utf8 => Ok(|agg_bufs: &mut [AggBuf], addr: u64, v: &ArrayRef| { let value = v.as_any().downcast_ref::().unwrap(); for (agg_buf, v) in agg_bufs.iter_mut().zip(value.iter()) { @@ -416,7 +425,7 @@ fn get_partial_buf_merger( DataType::Timestamp(TimeUnit::Millisecond, _) => fn_fixed!(TimestampMillisecond), DataType::Timestamp(TimeUnit::Microsecond, _) => fn_fixed!(TimestampMicrosecond), DataType::Timestamp(TimeUnit::Nanosecond, _) => fn_fixed!(TimestampNanosecond), - DataType::Decimal128(_, _) => fn_fixed!(Decimal128), + DataType::Decimal128(..) => fn_fixed!(Decimal128), DataType::Utf8 => Ok(|agg_buf1, agg_buf2, addr| { let v = AggDynStr::value(agg_buf2.dyn_value_mut(addr)); if v.is_some() { diff --git a/native-engine/datafusion-ext-plans/src/agg/mod.rs b/native-engine/datafusion-ext-plans/src/agg/mod.rs index edaaad28..a4986b90 100644 --- a/native-engine/datafusion-ext-plans/src/agg/mod.rs +++ b/native-engine/datafusion-ext-plans/src/agg/mod.rs @@ -24,16 +24,17 @@ pub mod first_ignores_null; pub mod maxmin; pub mod sum; -use crate::agg::agg_buf::{AccumInitialValue, AggBuf, AggDynBinary, AggDynScalar, AggDynStr}; -use arrow::array::*; -use arrow::datatypes::*; -use datafusion::common::{DataFusionError, Result, ScalarValue}; -use datafusion::logical_expr::aggregate_function; -use datafusion::physical_expr::PhysicalExpr; +use std::{any::Any, fmt::Debug, sync::Arc}; + +use arrow::{array::*, datatypes::*}; +use datafusion::{ + common::{DataFusionError, Result, ScalarValue}, + logical_expr::aggregate_function, + physical_expr::PhysicalExpr, +}; use datafusion_ext_exprs::cast::TryCastExpr; -use std::any::Any; -use std::fmt::Debug; -use std::sync::Arc; + +use crate::agg::agg_buf::{AccumInitialValue, AggBuf, AggDynBinary, AggDynScalar, AggDynStr}; pub const AGG_BUF_COLUMN_NAME: &str = "#9223372036854775807"; diff --git a/native-engine/datafusion-ext-plans/src/agg/sum.rs b/native-engine/datafusion-ext-plans/src/agg/sum.rs index 4afef71b..a9f71bbc 100644 --- a/native-engine/datafusion-ext-plans/src/agg/sum.rs +++ b/native-engine/datafusion-ext-plans/src/agg/sum.rs @@ -12,19 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::agg::agg_buf::{AccumInitialValue, AggBuf}; -use crate::agg::Agg; -use arrow::array::*; -use arrow::datatypes::*; -use datafusion::common::{Result, ScalarValue}; -use datafusion::error::DataFusionError; +use std::{ + any::Any, + fmt::{Debug, Formatter}, + ops::Add, + sync::Arc, +}; -use datafusion::physical_expr::PhysicalExpr; +use arrow::{array::*, datatypes::*}; +use datafusion::{ + common::{Result, ScalarValue}, + error::DataFusionError, + physical_expr::PhysicalExpr, +}; use paste::paste; -use std::any::Any; -use std::fmt::{Debug, Formatter}; -use std::ops::Add; -use std::sync::Arc; + +use crate::agg::{ + agg_buf::{AccumInitialValue, AggBuf}, + Agg, +}; pub struct AggSum { child: Arc, @@ -37,7 +43,9 @@ pub struct AggSum { impl AggSum { pub fn try_new(child: Arc, data_type: DataType) -> Result { - let accums_initial = vec![AccumInitialValue::Scalar(ScalarValue::try_from(&data_type)?)]; + let accums_initial = vec![AccumInitialValue::Scalar(ScalarValue::try_from( + &data_type, + )?)]; let partial_updater = get_partial_updater(&data_type)?; let partial_batch_updater = get_partial_batch_updater(&data_type)?; let partial_buf_merger = get_partial_buf_merger(&data_type)?; @@ -286,7 +294,7 @@ fn get_partial_buf_merger(dt: &DataType) -> Result fn_fixed!(UInt16), DataType::UInt32 => fn_fixed!(UInt32), DataType::UInt64 => fn_fixed!(UInt64), - DataType::Decimal128(_, _) => fn_fixed!(Decimal128), + DataType::Decimal128(..) => fn_fixed!(Decimal128), other => Err(DataFusionError::NotImplemented(format!( "unsupported data type in sum(): {}", other diff --git a/native-engine/datafusion-ext-plans/src/agg_exec.rs b/native-engine/datafusion-ext-plans/src/agg_exec.rs index 65f6b4d1..ed0bf21f 100644 --- a/native-engine/datafusion-ext-plans/src/agg_exec.rs +++ b/native-engine/datafusion-ext-plans/src/agg_exec.rs @@ -12,34 +12,43 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use arrow::array::ArrayRef; -use arrow::datatypes::{FieldRef, SchemaRef}; -use arrow::error::ArrowError; -use arrow::record_batch::{RecordBatch, RecordBatchOptions}; -use arrow::row::{RowConverter, SortField}; -use datafusion::common::{Result, Statistics}; -use datafusion::execution::context::TaskContext; -use datafusion::physical_expr::PhysicalSortExpr; -use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; -use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, +use std::{ + any::Any, + fmt::{Debug, Formatter}, + sync::Arc, +}; + +use arrow::{ + datatypes::SchemaRef, + error::ArrowError, + record_batch::{RecordBatch, RecordBatchOptions}, +}; +use datafusion::{ + common::{Result, Statistics}, + execution::context::TaskContext, + physical_expr::PhysicalSortExpr, + physical_plan::{ + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, + stream::RecordBatchStreamAdapter, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, + }, +}; +use datafusion_ext_commons::{slim_bytes::SlimBytes, streams::coalesce_stream::CoalesceInput}; +use futures::{stream::once, StreamExt, TryFutureExt, TryStreamExt}; + +use crate::{ + agg::{ + agg_buf::AggBuf, + agg_context::AggContext, + agg_tables::{AggTables, InMemMode}, + AggExecMode, AggExpr, GroupingExpr, + }, + common::{ + batch_statisitcs::{stat_input, InputBatchStatistics}, + output::TaskOutputter, + }, + memmgr::MemManager, }; -use datafusion_ext_commons::streams::coalesce_stream::CoalesceStream; -use futures::stream::once; -use futures::{StreamExt, TryFutureExt, TryStreamExt}; -use std::any::Any; -use std::fmt::{Debug, Formatter}; -use std::sync::Arc; - -use crate::agg::agg_buf::AggBuf; -use crate::agg::agg_context::AggContext; -use crate::agg::agg_tables::AggTables; -use crate::agg::{AggExecMode, AggExpr, GroupingExpr}; -use crate::common::batch_statisitcs::{stat_input, InputBatchStatistics}; -use crate::common::memory_manager::MemManager; -use crate::common::output::{output_bufferable_with_spill, output_with_sender}; -use crate::common::slim_bytes::SlimBytes; #[derive(Debug)] pub struct AggExec { @@ -54,6 +63,7 @@ impl AggExec { groupings: Vec, aggs: Vec, initial_input_buffer_offset: usize, + supports_partial_skipping: bool, input: Arc, ) -> Result { let agg_ctx = Arc::new(AggContext::try_new( @@ -62,6 +72,7 @@ impl AggExec { groupings, aggs, initial_input_buffer_offset, + supports_partial_skipping, )?); Ok(Self { @@ -173,16 +184,6 @@ async fn execute_agg_with_grouping_hash( let baseline_metrics = BaselineMetrics::new(&metrics, partition_id); let timer = baseline_metrics.elapsed_compute().timer(); - // create grouping row converter and parser - let mut grouping_row_converter = RowConverter::new( - agg_ctx - .grouping_schema - .fields() - .iter() - .map(|field: &FieldRef| SortField::new(field.data_type().clone())) - .collect(), - )?; - // create tables let tables = Arc::new(AggTables::new( partition_id, @@ -198,13 +199,9 @@ async fn execute_agg_with_grouping_hash( InputBatchStatistics::from_metrics_set_and_blaze_conf(&metrics, partition_id)?, input.execute(partition_id, context.clone())?, )?; - let mut coalesced = Box::pin(CoalesceStream::new( - input, - context.session_config().batch_size(), - BaselineMetrics::new(&metrics, partition_id) - .elapsed_compute() - .clone(), - )); + let mut coalesced = context + 
.coalesce_with_default_batch_size(input, &BaselineMetrics::new(&metrics, partition_id))?; + while let Some(input_batch) = coalesced .next() .await @@ -213,54 +210,42 @@ async fn execute_agg_with_grouping_hash( { let _timer = baseline_metrics.elapsed_compute().timer(); - // compute grouping rows - let grouping_arrays: Vec = agg_ctx - .groupings - .iter() - .map(|grouping: &GroupingExpr| grouping.expr.evaluate(&input_batch)) - .map(|r| r.map(|columnar| columnar.into_array(input_batch.num_rows()))) - .collect::>() - .map_err(|err| err.context("agg: evaluating grouping arrays error"))?; - let grouping_rows = grouping_row_converter.convert_columns(&grouping_arrays)?; - - // compute input arrays - let input_arrays = agg_ctx - .create_input_arrays(&input_batch) - .map_err(|err| err.context("agg: evaluating input arrays error"))?; - let agg_buf_array = agg_ctx - .get_input_agg_buf_array(&input_batch) - .map_err(|err| err.context("agg: evaluating input agg-buf arrays error"))?; - // insert or update rows into in-mem table - tables - .update_entries(grouping_rows, |agg_bufs| { - let mut mem_diff = 0; - mem_diff += agg_ctx.partial_batch_update_input(agg_bufs, &input_arrays)?; - mem_diff += agg_ctx.partial_batch_merge_input(agg_bufs, agg_buf_array)?; - Ok(mem_diff) - }) - .await?; + tables.process_input_batch(input_batch).await?; + + // stop aggregating if triggered partial skipping + if tables.mode().await == InMemMode::PartialSkipped { + break; + } } let has_spill = tables.has_spill().await; let tables_cloned = tables.clone(); // merge all tables and output - let output = output_with_sender( - "Agg", - context.clone(), - agg_ctx.output_schema.clone(), - |sender| async move { + let output_schema = agg_ctx.output_schema.clone(); + let output = context.output_with_sender("Agg", output_schema, |sender| async move { + // output all aggregated records in table + tables + .output(baseline_metrics.clone(), sender.clone()) + .await?; + + // in partial skipping mode, there might be unconsumed records in input stream + while let Some(input_batch) = coalesced + .next() + .await + .transpose() + .map_err(|err| err.context("agg: polling batches from input error"))? + { tables - .output(grouping_row_converter, baseline_metrics, sender) - .await - .map_err(|err| err.context("agg: executing output error"))?; - Ok(()) - }, - )?; + .process_partial_skipped(input_batch, baseline_metrics.clone(), sender.clone()) + .await?; + } + Ok(()) + })?; // if running in-memory, buffer output when memory usage is high if !has_spill { - return output_bufferable_with_spill(tables_cloned, context, output); + return context.output_bufferable_with_spill(tables_cloned, output); } Ok(output) } @@ -280,61 +265,37 @@ async fn execute_agg_no_grouping( InputBatchStatistics::from_metrics_set_and_blaze_conf(&metrics, partition_id)?, input.execute(partition_id, context.clone())?, )?; - let mut coalesced = Box::pin(CoalesceStream::new( - input, - context.session_config().batch_size(), - baseline_metrics.elapsed_compute().clone(), - )); + let mut coalesced = context.coalesce_with_default_batch_size(input, &baseline_metrics)?; while let Some(input_batch) = coalesced.next().await.transpose()? 
{ let elapsed_compute = baseline_metrics.elapsed_compute().clone(); let _timer = elapsed_compute.timer(); - let input_arrays = agg_ctx - .create_input_arrays(&input_batch) - .map_err(|err| err.context("agg: evaluating input arrays error"))?; - agg_ctx - .partial_update_input_all(&mut agg_buf, &input_arrays) - .map_err(|err| err.context("agg: executing partial_update_input_all() error"))?; - - let agg_buf_array = agg_ctx - .get_input_agg_buf_array(&input_batch) - .map_err(|err| err.context("agg: evaluating input agg-buf arrays error"))?; - agg_ctx - .partial_merge_input_all(&mut agg_buf, agg_buf_array) - .map_err(|err| err.context("agg: executing partial_merge_input_all() error"))?; + let input_arrays = agg_ctx.create_input_arrays(&input_batch)?; + let agg_buf_array = agg_ctx.get_input_agg_buf_array(&input_batch)?; + agg_ctx.partial_update_input_all(&mut agg_buf, &input_arrays)?; + agg_ctx.partial_merge_input_all(&mut agg_buf, agg_buf_array)?; } // output // in no-grouping mode, we always output only one record, so it is not // necessary to record elapsed computed time. - output_with_sender( - "Agg", - context, - agg_ctx.output_schema.clone(), - move |sender| async move { - let elapsed_compute = baseline_metrics.elapsed_compute().clone(); - let mut timer = elapsed_compute.timer(); - - let batch_result = agg_ctx - .build_agg_columns(vec![(&[], agg_buf)]) - .map_err(|e| ArrowError::ExternalError(Box::new(e))) - .and_then(|agg_columns| { - RecordBatch::try_new_with_options( - agg_ctx.output_schema.clone(), - agg_columns, - &RecordBatchOptions::new().with_row_count(Some(1)), - ) - .map(|batch| { - baseline_metrics.record_output(1); - batch - }) - }); - sender.send(Ok(batch_result?), Some(&mut timer)).await; - log::info!("aggregate exec (no grouping) outputting one record"); - Ok(()) - }, - ) + let output_schema = agg_ctx.output_schema.clone(); + context.output_with_sender("Agg", output_schema, move |sender| async move { + let elapsed_compute = baseline_metrics.elapsed_compute().clone(); + let mut timer = elapsed_compute.timer(); + + let agg_columns = agg_ctx.build_agg_columns(vec![(&[], agg_buf)])?; + let batch = RecordBatch::try_new_with_options( + agg_ctx.output_schema.clone(), + agg_columns, + &RecordBatchOptions::new().with_row_count(Some(1)), + )?; + baseline_metrics.record_output(1); + sender.send(Ok(batch), Some(&mut timer)).await; + log::info!("aggregate exec (no grouping) outputting one record"); + Ok(()) + }) } async fn execute_agg_sorted( @@ -346,132 +307,103 @@ async fn execute_agg_sorted( ) -> Result { let baseline_metrics = BaselineMetrics::new(&metrics, partition_id); let elapsed_compute = baseline_metrics.elapsed_compute().clone(); - - // create grouping row converter and parser - let mut grouping_row_converter = RowConverter::new( - agg_ctx - .grouping_schema - .fields() - .iter() - .map(|field: &FieldRef| SortField::new(field.data_type().clone())) - .collect(), - )?; + let batch_size = context.session_config().batch_size(); // start processing input batches let input = stat_input( InputBatchStatistics::from_metrics_set_and_blaze_conf(&metrics, partition_id)?, input.execute(partition_id, context.clone())?, )?; - let mut coalesced = Box::pin(CoalesceStream::new( - input, - context.session_config().batch_size(), - baseline_metrics.elapsed_compute().clone(), - )); - output_with_sender( - "Agg", - context.clone(), - agg_ctx.output_schema.clone(), - |sender| async move { - let batch_size = context.session_config().batch_size(); - let mut staging_records = vec![]; - let mut 
current_record: Option<(SlimBytes, AggBuf)> = None; - let mut timer = elapsed_compute.timer(); - timer.stop(); - - macro_rules! flush_staging { - () => {{ - let batch = agg_ctx.convert_records_to_batch( - &mut grouping_row_converter, - std::mem::take(&mut staging_records), - )?; - log::info!( - "aggregate exec (sorted) outputting one batch: num_rows={}", - batch.num_rows(), - ); - baseline_metrics.record_output(batch.num_rows()); - sender.send(Ok(batch), Some(&mut timer)).await; - }}; - } - while let Some(input_batch) = coalesced.next().await.transpose()? { - timer.restart(); - - // compute grouping rows - let grouping_arrays: Vec = agg_ctx - .groupings - .iter() - .map(|grouping: &GroupingExpr| grouping.expr.evaluate(&input_batch)) - .map(|r| r.map(|columnar| columnar.into_array(input_batch.num_rows()))) - .collect::>() - .map_err(|err| err.context("agg: evaluating grouping arrays error"))?; - let grouping_rows: Vec = grouping_row_converter - .convert_columns(&grouping_arrays)? - .into_iter() - .map(|row| row.as_ref().into()) - .collect(); - - // compute input arrays - let input_arrays = agg_ctx - .create_input_arrays(&input_batch) - .map_err(|err| err.context("agg: evaluating input arrays error"))?; - let agg_buf_array = agg_ctx - .get_input_agg_buf_array(&input_batch) - .map_err(|err| err.context("agg: evaluating input agg-buf arrays error"))?; - - // update to current record - for (row_idx, grouping_row) in grouping_rows.into_iter().enumerate() { - // if group key differs, renew one and move the old record to staging - if Some(&grouping_row) != current_record.as_ref().map(|r| &r.0) { - let finished_record = - current_record.replace((grouping_row, agg_ctx.initial_agg_buf.clone())); - if let Some(record) = finished_record { - staging_records.push(record); - if staging_records.len() >= batch_size { - flush_staging!(); - } + let mut coalesced = context.coalesce_with_default_batch_size(input, &baseline_metrics)?; + + let output_schema = agg_ctx.output_schema.clone(); + context.output_with_sender("Agg", output_schema, move |sender| async move { + let mut staging_records = vec![]; + let mut current_record: Option<(SlimBytes, AggBuf)> = None; + let mut timer = elapsed_compute.timer(); + timer.stop(); + + macro_rules! flush_staging { + () => {{ + let batch = + agg_ctx.convert_records_to_batch(std::mem::take(&mut staging_records))?; + let num_rows = batch.num_rows(); + log::info!("aggregate exec (sorted) outputting one batch: num_rows={num_rows}"); + baseline_metrics.record_output(num_rows); + sender.send(Ok(batch), Some(&mut timer)).await; + }}; + } + while let Some(input_batch) = coalesced.next().await.transpose()? 
{ + timer.restart(); + + // compute grouping rows + let grouping_rows = agg_ctx.create_grouping_rows(&input_batch)?; + + // compute input arrays + let input_arrays = agg_ctx.create_input_arrays(&input_batch)?; + let agg_buf_array = agg_ctx.get_input_agg_buf_array(&input_batch)?; + + // update to current record + for (row_idx, grouping_row) in grouping_rows.into_iter().enumerate() { + // if group key differs, renew one and move the old record to staging + if Some(grouping_row.as_ref()) != current_record.as_ref().map(|r| r.0.as_ref()) { + let finished_record = current_record.replace(( + grouping_row.as_ref().into(), + agg_ctx.initial_agg_buf.clone(), + )); + if let Some(record) = finished_record { + staging_records.push(record); + if staging_records.len() >= batch_size { + flush_staging!(); } } - let agg_buf = &mut current_record.as_mut().unwrap().1; - agg_ctx - .partial_update_input(agg_buf, &input_arrays, row_idx) - .map_err(|err| { - err.context("agg: executing partial_update_input() error") - })?; - agg_ctx - .partial_merge_input(agg_buf, agg_buf_array, row_idx) - .map_err(|err| err.context("agg: executing partial_merge_input() error"))?; } - timer.stop(); + let agg_buf = &mut current_record.as_mut().unwrap().1; + agg_ctx.partial_update_input(agg_buf, &input_arrays, row_idx)?; + agg_ctx.partial_merge_input(agg_buf, agg_buf_array, row_idx)?; } + timer.stop(); + } - if let Some(record) = current_record { - staging_records.push(record); - } - if !staging_records.is_empty() { - flush_staging!(); - } - Ok(()) - }, - ) + if let Some(record) = current_record { + staging_records.push(record); + } + if !staging_records.is_empty() { + flush_staging!(); + } + Ok(()) + }) } + #[cfg(test)] mod test { - use crate::agg::AggExecMode::HashAgg; - use crate::agg::AggMode::{Final, Partial}; - use crate::agg::{create_agg, AggExpr, AggFunction, GroupingExpr}; - use crate::agg_exec::AggExec; - use crate::common::memory_manager::MemManager; - use arrow::array::Int32Array; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use datafusion::assert_batches_sorted_eq; - use datafusion::common::{Result, ScalarValue}; - use datafusion::physical_expr::expressions as phys_expr; - use datafusion::physical_expr::expressions::Column; - use datafusion::physical_plan::memory::MemoryExec; - use datafusion::physical_plan::{common, ExecutionPlan}; - use datafusion::prelude::SessionContext; use std::sync::Arc; + use arrow::{ + array::Int32Array, + datatypes::{DataType, Field, Schema}, + record_batch::RecordBatch, + }; + use datafusion::{ + assert_batches_sorted_eq, + common::{Result, ScalarValue}, + physical_expr::{expressions as phys_expr, expressions::Column}, + physical_plan::{common, memory::MemoryExec, ExecutionPlan}, + prelude::SessionContext, + }; + + use crate::{ + agg::{ + create_agg, + AggExecMode::HashAgg, + AggExpr, AggFunction, + AggMode::{Final, Partial}, + GroupingExpr, + }, + agg_exec::AggExec, + memmgr::MemManager, + }; + fn build_table_i32( a: (&str, &Vec), b: (&str, &Vec), @@ -660,6 +592,7 @@ mod test { }], aggs_agg_expr.clone(), 0, + false, input, )?; @@ -682,6 +615,7 @@ mod test { }) .collect::>()?, 0, + false, Arc::new(agg_exec_partial), )?; diff --git a/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs b/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs index 047268e6..9377dd6f 100644 --- a/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs +++ b/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs @@ -12,35 +12,40 @@ // 
See the License for the specific language governing permissions and // limitations under the License. -use crate::sort_exec::SortExec; -use crate::sort_merge_join_exec::SortMergeJoinExec; -use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; -use blaze_jni_bridge::jni_call_static; -use datafusion::common::{DataFusionError, Result, Statistics}; -use datafusion::execution::context::TaskContext; -use datafusion::logical_expr::JoinType; -use datafusion::physical_expr::PhysicalSortExpr; -use datafusion::physical_plan::expressions::Column; -use datafusion::physical_plan::joins::utils::{ - build_join_schema, check_join_is_valid, JoinFilter, JoinOn, +use std::{ + any::Any, + fmt::{Debug, Formatter}, + sync::Arc, + task::Poll, + time::Duration, }; -use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode}; -use datafusion::physical_plan::memory::MemoryStream; -use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; -use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, + +use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use blaze_jni_bridge::{ + conf, + conf::{BooleanConf, IntConf}, +}; +use datafusion::{ + common::{DataFusionError, Result, Statistics}, + execution::context::TaskContext, + logical_expr::JoinType, + physical_expr::PhysicalSortExpr, + physical_plan::{ + expressions::Column, + joins::{ + utils::{build_join_schema, check_join_is_valid, JoinFilter, JoinOn}, + HashJoinExec, PartitionMode, + }, + memory::MemoryStream, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, + stream::RecordBatchStreamAdapter, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, + }, }; -use futures::stream::once; -use futures::{StreamExt, TryStreamExt}; -use jni::sys::{jboolean, JNI_TRUE}; +use futures::{stream::once, StreamExt, TryStreamExt}; use parking_lot::Mutex; -use std::any::Any; -use std::fmt::{Debug, Formatter}; -use std::sync::Arc; -use std::task::Poll; -use std::time::Duration; + +use crate::{sort_exec::SortExec, sort_merge_join_exec::SortMergeJoinExec}; #[derive(Debug)] pub struct BroadcastJoinExec { @@ -178,12 +183,9 @@ async fn execute_broadcast_join( join_filter: Option, metrics: BaselineMetrics, ) -> Result { - let enabled_fallback_to_smj: bool = - jni_call_static!(BlazeConf.enableBhjFallbacksToSmj() -> jboolean)? == JNI_TRUE; - let bhj_num_rows_limit: usize = - jni_call_static!(BlazeConf.bhjFallbacksToSmjRowsThreshold() -> i32)? as usize; - let bhj_mem_size_limit: usize = - jni_call_static!(BlazeConf.bhjFallbacksToSmjMemThreshold() -> i32)? as usize; + let enabled_fallback_to_smj = conf::BHJ_FALLBACKS_TO_SMJ_ENABLE.value()?; + let bhj_num_rows_limit = conf::BHJ_FALLBACKS_TO_SMJ_ROWS_THRESHOLD.value()? as usize; + let bhj_mem_size_limit = conf::BHJ_FALLBACKS_TO_SMJ_MEM_THRESHOLD.value()? 
as usize; // if broadcasted size is small enough, use hash join // otherwise use sort-merge join @@ -304,10 +306,13 @@ async fn execute_broadcast_join( let join_metrics = join.metrics().unwrap(); metrics.record_output(join_metrics.output_rows().unwrap_or(0)); metrics.elapsed_compute().add_duration(Duration::from_nanos( - [right_sorted_metrics.elapsed_compute(), join_metrics.elapsed_compute()] - .into_iter() - .flatten() - .sum::() as u64, + [ + right_sorted_metrics.elapsed_compute(), + join_metrics.elapsed_compute(), + ] + .into_iter() + .flatten() + .sum::() as u64, )); Poll::Ready(None) })); diff --git a/native-engine/datafusion-ext-plans/src/broadcast_nested_loop_join_exec.rs b/native-engine/datafusion-ext-plans/src/broadcast_nested_loop_join_exec.rs index 4e195d2e..d58843d4 100644 --- a/native-engine/datafusion-ext-plans/src/broadcast_nested_loop_join_exec.rs +++ b/native-engine/datafusion-ext-plans/src/broadcast_nested_loop_join_exec.rs @@ -12,24 +12,28 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::broadcast_join_exec::RecordBatchStreamsWrapperExec; +use std::{any::Any, fmt::Formatter, sync::Arc}; + use arrow::datatypes::SchemaRef; -use datafusion::common::{JoinType, Result, Statistics}; -use datafusion::execution::{SendableRecordBatchStream, TaskContext}; - -use datafusion::physical_expr::{Partitioning, PhysicalSortExpr}; -use datafusion::physical_plan::joins::utils::{build_join_schema, check_join_is_valid, JoinFilter}; -use datafusion::physical_plan::joins::NestedLoopJoinExec; -use datafusion::physical_plan::memory::MemoryExec; -use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; -use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; -use futures::stream::once; -use futures::{StreamExt, TryStreamExt}; +use datafusion::{ + common::{JoinType, Result, Statistics}, + execution::{SendableRecordBatchStream, TaskContext}, + physical_expr::{Partitioning, PhysicalSortExpr}, + physical_plan::{ + joins::{ + utils::{build_join_schema, check_join_is_valid, JoinFilter}, + NestedLoopJoinExec, + }, + memory::MemoryExec, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, + stream::RecordBatchStreamAdapter, + DisplayAs, DisplayFormatType, ExecutionPlan, + }, +}; +use futures::{stream::once, StreamExt, TryStreamExt}; use parking_lot::Mutex; -use std::any::Any; -use std::fmt::Formatter; -use std::sync::Arc; + +use crate::broadcast_join_exec::RecordBatchStreamsWrapperExec; #[derive(Debug)] pub struct BroadcastNestedLoopJoinExec { diff --git a/native-engine/datafusion-ext-plans/src/common/batch_statisitcs.rs b/native-engine/datafusion-ext-plans/src/common/batch_statisitcs.rs index 9044ca6a..4e6aca56 100644 --- a/native-engine/datafusion-ext-plans/src/common/batch_statisitcs.rs +++ b/native-engine/datafusion-ext-plans/src/common/batch_statisitcs.rs @@ -13,11 +13,15 @@ // limitations under the License. 
use arrow::record_batch::RecordBatch; -use blaze_jni_bridge::{is_jni_bridge_inited, jni_call_static}; -use datafusion::common::Result; -use datafusion::execution::SendableRecordBatchStream; -use datafusion::physical_plan::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder}; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use blaze_jni_bridge::{conf, conf::BooleanConf, is_jni_bridge_inited}; +use datafusion::{ + common::Result, + execution::SendableRecordBatchStream, + physical_plan::{ + metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder}, + stream::RecordBatchStreamAdapter, + }, +}; use futures::StreamExt; #[derive(Clone)] @@ -34,8 +38,7 @@ impl InputBatchStatistics { metrics_set: &ExecutionPlanMetricsSet, partition: usize, ) -> Result> { - let enabled = is_jni_bridge_inited() - && jni_call_static!(BlazeConf.enableInputBatchStatistics() -> bool)?; + let enabled = is_jni_bridge_inited() && conf::INPUT_BATCH_STATISTICS_ENABLE.value()?; Ok(enabled.then_some(Self::from_metrics_set(metrics_set, partition))) } diff --git a/native-engine/datafusion-ext-plans/src/common/cached_exprs_evaluator.rs b/native-engine/datafusion-ext-plans/src/common/cached_exprs_evaluator.rs index 5efdd538..4dc18613 100644 --- a/native-engine/datafusion-ext-plans/src/common/cached_exprs_evaluator.rs +++ b/native-engine/datafusion-ext-plans/src/common/cached_exprs_evaluator.rs @@ -12,28 +12,37 @@ // See the License for the specific language governing permissions and // limitations under the License. -use arrow::array::{Array, ArrayRef, BooleanArray}; -use arrow::compute::{filter, filter_record_batch, prep_null_mask_filter}; -use arrow::datatypes::{DataType, Schema, SchemaRef}; -use arrow::record_batch::{RecordBatch, RecordBatchOptions}; -use datafusion::common::cast::as_boolean_array; -use datafusion::common::tree_node::{Transformed, TreeNode}; -use datafusion::common::{Result, ScalarValue}; -use datafusion::physical_expr::expressions::{ - CaseExpr, Column, Literal, NoOp, SCAndExpr, SCOrExpr, +use std::{ + any::Any, + cell::RefCell, + collections::{HashMap, HashSet}, + fmt::{Debug, Display, Formatter}, + hash::{Hash, Hasher}, + rc::Rc, + sync::Arc, +}; + +use arrow::{ + array::{Array, ArrayRef, BooleanArray}, + compute::{filter, filter_record_batch, prep_null_mask_filter}, + datatypes::{DataType, Schema, SchemaRef}, + record_batch::{RecordBatch, RecordBatchOptions}, +}; +use datafusion::{ + common::{ + cast::as_boolean_array, + tree_node::{Transformed, TreeNode}, + Result, ScalarValue, + }, + physical_expr::{ + expressions::{CaseExpr, Column, Literal, NoOp, SCAndExpr, SCOrExpr}, + scatter, PhysicalExpr, PhysicalExprRef, + }, + physical_plan::ColumnarValue, }; -use datafusion::physical_expr::{scatter, PhysicalExpr, PhysicalExprRef}; -use datafusion::physical_plan::ColumnarValue; use datafusion_ext_commons::uda::UserDefinedArray; use itertools::Itertools; use parking_lot::Mutex; -use std::any::Any; -use std::cell::RefCell; -use std::collections::{HashMap, HashSet}; -use std::fmt::{Debug, Display, Formatter}; -use std::hash::{Hash, Hasher}; -use std::rc::Rc; -use std::sync::Arc; pub struct CachedExprsEvaluator { transformed_projection_exprs: Vec, diff --git a/native-engine/datafusion-ext-plans/src/common/column_pruning.rs b/native-engine/datafusion-ext-plans/src/common/column_pruning.rs index 1a5df971..c8c1a3a0 100644 --- a/native-engine/datafusion-ext-plans/src/common/column_pruning.rs +++ b/native-engine/datafusion-ext-plans/src/common/column_pruning.rs @@ -15,21 +15,23 @@ // specific language 
governing permissions and limitations // under the License. +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; + use arrow::record_batch::RecordBatch; -use datafusion::common::{ - tree_node::{Transformed, TreeNode}, - Result, +use datafusion::{ + common::{ + tree_node::{Transformed, TreeNode}, + Result, + }, + execution::{SendableRecordBatchStream, TaskContext}, + physical_expr::{expressions::Column, utils::collect_columns, PhysicalExprRef}, + physical_plan::{stream::RecordBatchStreamAdapter, ExecutionPlan}, }; -use datafusion::execution::{SendableRecordBatchStream, TaskContext}; -use datafusion::physical_expr::expressions::Column; -use datafusion::physical_expr::utils::collect_columns; -use datafusion::physical_expr::PhysicalExprRef; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; -use datafusion::physical_plan::ExecutionPlan; use futures::StreamExt; use itertools::Itertools; -use std::collections::{HashMap, HashSet}; -use std::sync::Arc; pub trait ExecuteWithColumnPruning { fn execute_projected( diff --git a/native-engine/datafusion-ext-plans/src/common/mod.rs b/native-engine/datafusion-ext-plans/src/common/mod.rs index b09154b9..54134c66 100644 --- a/native-engine/datafusion-ext-plans/src/common/mod.rs +++ b/native-engine/datafusion-ext-plans/src/common/mod.rs @@ -12,21 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -use arrow::array::{ArrayRef, PrimitiveArray, UInt32Array}; -use arrow::datatypes::SchemaRef; -use arrow::error::Result as ArrowResult; -use arrow::record_batch::{RecordBatch, RecordBatchOptions}; +use arrow::{ + array::{ArrayRef, PrimitiveArray, UInt32Array}, + datatypes::SchemaRef, + error::Result as ArrowResult, + record_batch::{RecordBatch, RecordBatchOptions}, +}; use datafusion::common::Result; pub mod batch_statisitcs; -pub mod bytes_arena; pub mod cached_exprs_evaluator; pub mod column_pruning; -pub mod memory_manager; -pub mod onheap_spill; pub mod output; -pub mod rdxsort; -pub mod slim_bytes; pub struct BatchTaker<'a>(pub &'a RecordBatch); diff --git a/native-engine/datafusion-ext-plans/src/common/output.rs b/native-engine/datafusion-ext-plans/src/common/output.rs index b9f3144b..95f3fdc0 100644 --- a/native-engine/datafusion-ext-plans/src/common/output.rs +++ b/native-engine/datafusion-ext-plans/src/common/output.rs @@ -12,27 +12,30 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use std::future::Future; -use std::io::{Cursor, Write}; -use std::panic::AssertUnwindSafe; -use std::sync::{Arc, Weak}; - -use crate::common::memory_manager::{MemConsumer, MemManager}; -use crate::common::onheap_spill::try_new_spill; -use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; +use std::{ + future::Future, + io::{Cursor, Write}, + panic::AssertUnwindSafe, + sync::{Arc, Weak}, +}; + +use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; use blaze_jni_bridge::is_task_running; -use datafusion::common::{DataFusionError, Result}; -use datafusion::execution::context::TaskContext; -use datafusion::physical_plan::metrics::ScopedTimerGuard; -use datafusion::physical_plan::stream::RecordBatchReceiverStream; -use datafusion::physical_plan::SendableRecordBatchStream; +use datafusion::{ + common::{DataFusionError, Result}, + execution::context::TaskContext, + physical_plan::{ + metrics::ScopedTimerGuard, stream::RecordBatchReceiverStream, SendableRecordBatchStream, + }, +}; use datafusion_ext_commons::io::{read_one_batch, write_one_batch}; use futures::{FutureExt, StreamExt, TryFutureExt}; use once_cell::sync::OnceCell; use parking_lot::Mutex; use tokio::sync::mpsc::Sender; +use crate::memmgr::{onheap_spill::try_new_spill, MemConsumer, MemManager}; + fn working_senders() -> &'static Mutex>> { static WORKING_SENDERS: OnceCell>>> = OnceCell::new(); WORKING_SENDERS.get_or_init(|| Mutex::default()) @@ -89,74 +92,87 @@ impl WrappedRecordBatchSender { } } -pub fn output_with_sender> + Send>( - desc: &'static str, - task_context: Arc, - output_schema: SchemaRef, - output: impl FnOnce(Arc) -> Fut + Send + 'static, -) -> Result { - let mut stream_builder = RecordBatchReceiverStream::builder(output_schema.clone(), 1); - let sender = stream_builder.tx().clone(); - let err_sender = sender.clone(); - - stream_builder.spawn(async move { - let wrapped = WrappedRecordBatchSender::new(task_context, sender); - let result = AssertUnwindSafe(async move { - let task_running = is_task_running(); - if !task_running { - panic!( - "output_with_sender[{}] canceled due to task finished/killed", - desc - ); - } - output(wrapped) - .unwrap_or_else(|err| { +pub trait TaskOutputter { + fn output_with_sender> + Send>( + &self, + desc: &'static str, + output_schema: SchemaRef, + output: impl FnOnce(Arc) -> Fut + Send + 'static, + ) -> Result; + + fn output_bufferable_with_spill( + &self, + mem_consumer: Arc, + stream: SendableRecordBatchStream, + ) -> Result; +} + +impl TaskOutputter for Arc { + fn output_with_sender> + Send>( + &self, + desc: &'static str, + output_schema: SchemaRef, + output: impl FnOnce(Arc) -> Fut + Send + 'static, + ) -> Result { + let mut stream_builder = RecordBatchReceiverStream::builder(output_schema, 1); + let sender = stream_builder.tx().clone(); + let err_sender = sender.clone(); + let wrapped_sender = WrappedRecordBatchSender::new(self.clone(), sender); + + stream_builder.spawn(async move { + let result = AssertUnwindSafe(async move { + let task_running = is_task_running(); + if !task_running { panic!( - "output_with_sender[{}]: output() returns error: {}", - desc, err + "output_with_sender[{}] canceled due to task finished/killed", + desc ); - }) - .await - }) - .catch_unwind() - .await - .map(|_| Ok(())) - .unwrap_or_else(|err| { - let panic_message = panic_message::get_panic_message(&err).unwrap_or("unknown error"); - Err(DataFusionError::Execution(panic_message.to_owned())) - }); - - if let Err(err) = result { - let err_message = err.to_string(); - let _ = 
err_sender.send(Err(err)).await; - - // panic current spawn - let task_running = is_task_running(); - if !task_running { - panic!( - "output_with_sender[{}] canceled due to task finished/killed", - desc - ); - } else { - panic!("output_with_sender[{}] error: {}", desc, err_message); + } + output(wrapped_sender) + .unwrap_or_else(|err| { + panic!( + "output_with_sender[{}]: output() returns error: {}", + desc, err + ); + }) + .await + }) + .catch_unwind() + .await + .map(|_| Ok(())) + .unwrap_or_else(|err| { + let panic_message = + panic_message::get_panic_message(&err).unwrap_or("unknown error"); + Err(DataFusionError::Execution(panic_message.to_owned())) + }); + + if let Err(err) = result { + let err_message = err.to_string(); + let _ = err_sender.send(Err(err)).await; + + // panic current spawn + let task_running = is_task_running(); + if !task_running { + panic!( + "output_with_sender[{}] canceled due to task finished/killed", + desc + ); + } else { + panic!("output_with_sender[{}] error: {}", desc, err_message); + } } - } - }); - Ok(stream_builder.build()) -} + }); + Ok(stream_builder.build()) + } -pub fn output_bufferable_with_spill( - mem_consumer: Arc, - task_context: Arc, - mut stream: SendableRecordBatchStream, -) -> Result { - let output_schema = stream.schema(); - - output_with_sender( - "OutputBufferableWithSpill", - task_context, - output_schema.clone(), - move |sender| async move { + fn output_bufferable_with_spill( + &self, + mem_consumer: Arc, + mut stream: SendableRecordBatchStream, + ) -> Result { + let schema = stream.schema(); + let desc = "OutputBufferableWithSpill"; + self.output_with_sender(desc, schema.clone(), move |sender| async move { while let Some(batch) = { // if consumer is holding too much memory, we will create a spill // to receive all of its outputs and release all memory. @@ -177,7 +193,7 @@ pub fn output_bufferable_with_spill( // read all batches from spill and output let mut spill_reader = spill.get_buf_reader(); while let Some(batch) = - read_one_batch(&mut spill_reader, Some(output_schema.clone()), true)? + read_one_batch(&mut spill_reader, Some(schema.clone()), true)? { sender.send(Ok(batch), None).await; } @@ -188,6 +204,6 @@ pub fn output_bufferable_with_spill( sender.send(Ok(batch), None).await; } Ok(()) - }, - ) + }) + } } diff --git a/native-engine/datafusion-ext-plans/src/debug_exec.rs b/native-engine/datafusion-ext-plans/src/debug_exec.rs index 19924e92..729337a6 100644 --- a/native-engine/datafusion-ext-plans/src/debug_exec.rs +++ b/native-engine/datafusion-ext-plans/src/debug_exec.rs @@ -12,32 +12,28 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use arrow::datatypes::SchemaRef; -use async_trait::async_trait; - -use arrow::record_batch::RecordBatch; -use arrow::util::pretty::pretty_format_batches; -use datafusion::error::DataFusionError; -use datafusion::error::Result; -use datafusion::execution::context::TaskContext; -use datafusion::physical_expr::PhysicalSortExpr; - -use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; - -use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, - SendableRecordBatchStream, Statistics, +use std::{ + any::Any, + fmt::Formatter, + pin::Pin, + sync::Arc, + task::{Context, Poll}, }; +use arrow::{datatypes::SchemaRef, record_batch::RecordBatch, util::pretty::pretty_format_batches}; +use async_trait::async_trait; +use datafusion::{ + error::{DataFusionError, Result}, + execution::context::TaskContext, + physical_expr::PhysicalSortExpr, + physical_plan::{ + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, + SendableRecordBatchStream, Statistics, + }, +}; use futures::{Stream, StreamExt}; -use std::any::Any; -use std::fmt::Formatter; - -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - #[derive(Debug)] pub struct DebugExec { input: Arc, diff --git a/native-engine/datafusion-ext-plans/src/empty_partitions_exec.rs b/native-engine/datafusion-ext-plans/src/empty_partitions_exec.rs index 4daf5652..adb5544e 100644 --- a/native-engine/datafusion-ext-plans/src/empty_partitions_exec.rs +++ b/native-engine/datafusion-ext-plans/src/empty_partitions_exec.rs @@ -12,24 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::any::Any; -use std::fmt::Formatter; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - -use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; +use std::{ + any::Any, + fmt::Formatter, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; use async_trait::async_trait; -use datafusion::error::DataFusionError; -use datafusion::error::Result; -use datafusion::execution::context::TaskContext; -use datafusion::physical_expr::PhysicalSortExpr; -use datafusion::physical_plan::metrics::MetricsSet; -use datafusion::physical_plan::Partitioning::UnknownPartitioning; -use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, - SendableRecordBatchStream, Statistics, +use datafusion::{ + error::{DataFusionError, Result}, + execution::context::TaskContext, + physical_expr::PhysicalSortExpr, + physical_plan::{ + metrics::MetricsSet, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, + Partitioning::UnknownPartitioning, RecordBatchStream, SendableRecordBatchStream, + Statistics, + }, }; use futures::Stream; diff --git a/native-engine/datafusion-ext-plans/src/expand_exec.rs b/native-engine/datafusion-ext-plans/src/expand_exec.rs index 438f0925..4b242ef0 100644 --- a/native-engine/datafusion-ext-plans/src/expand_exec.rs +++ b/native-engine/datafusion-ext-plans/src/expand_exec.rs @@ -12,24 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; -use datafusion::common::Result; -use datafusion::common::{DataFusionError, Statistics}; -use datafusion::execution::context::TaskContext; -use datafusion::physical_expr::{PhysicalExpr, PhysicalSortExpr}; -use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; -use datafusion::physical_plan::Partitioning::UnknownPartitioning; -use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, - SendableRecordBatchStream, +use std::{ + any::Any, + fmt::Formatter, + pin::Pin, + sync::Arc, + task::{ready, Context, Poll}, +}; + +use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use datafusion::{ + common::{DataFusionError, Result, Statistics}, + execution::context::TaskContext, + physical_expr::{PhysicalExpr, PhysicalSortExpr}, + physical_plan::{ + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, + Partitioning::UnknownPartitioning, + RecordBatchStream, SendableRecordBatchStream, + }, }; use futures::{Stream, StreamExt}; -use std::any::Any; -use std::fmt::Formatter; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{ready, Context, Poll}; #[derive(Debug, Clone)] pub struct ExpandExec { @@ -195,21 +198,25 @@ impl Stream for ExpandStream { #[cfg(test)] mod test { - use crate::common::memory_manager::MemManager; - use crate::expand_exec::ExpandExec; - use arrow::array::{BooleanArray, Float32Array, Int32Array, StringArray}; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use datafusion::assert_batches_eq; - use datafusion::common::{Result, ScalarValue}; - use datafusion::logical_expr::Operator; - use datafusion::physical_expr::expressions::{binary, col, lit}; - use datafusion::physical_plan::memory::MemoryExec; - use datafusion::physical_plan::{common, ExecutionPlan}; - use datafusion::prelude::SessionContext; use std::sync::Arc; - //build i32 table + use arrow::{ + array::{BooleanArray, Float32Array, Int32Array, StringArray}, + datatypes::{DataType, Field, Schema}, + record_batch::RecordBatch, + }; + use datafusion::{ + assert_batches_eq, + common::{Result, ScalarValue}, + logical_expr::Operator, + physical_expr::expressions::{binary, col, lit}, + physical_plan::{common, memory::MemoryExec, ExecutionPlan}, + prelude::SessionContext, + }; + + use crate::{expand_exec::ExpandExec, memmgr::MemManager}; + + // build i32 table fn build_table_i32(a: (&str, &Vec)) -> RecordBatch { let schema = Schema::new(vec![Field::new(a.0, DataType::Int32, false)]); @@ -226,7 +233,7 @@ mod test { Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) } - //build f32 table + // build f32 table fn build_table_f32(a: (&str, &Vec)) -> RecordBatch { let schema = Schema::new(vec![Field::new(a.0, DataType::Float32, false)]); @@ -243,7 +250,7 @@ mod test { Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) } - //build str table + // build str table fn build_table_str(a: (&str, &Vec)) -> RecordBatch { let schema = Schema::new(vec![Field::new(a.0, DataType::Utf8, false)]); @@ -260,7 +267,7 @@ mod test { Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) } - //build boolean table + // build boolean table fn build_table_bool(a: (&str, &Vec)) -> RecordBatch { let schema = Schema::new(vec![Field::new(a.0, DataType::Boolean, false)]); @@ -415,7 +422,12 @@ mod test { let input = 
build_table_string(( "a", - &vec!["hello".to_string(), ",".to_string(), "rust".to_string(), "!".to_string()], + &vec![ + "hello".to_string(), + ",".to_string(), + "rust".to_string(), + "!".to_string(), + ], )); let schema = Schema::new(vec![Field::new("test_str", DataType::Utf8, false)]); diff --git a/native-engine/datafusion-ext-plans/src/ffi_reader_exec.rs b/native-engine/datafusion-ext-plans/src/ffi_reader_exec.rs index befc6322..6638a2e0 100644 --- a/native-engine/datafusion-ext-plans/src/ffi_reader_exec.rs +++ b/native-engine/datafusion-ext-plans/src/ffi_reader_exec.rs @@ -12,24 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::{ + any::Any, + fmt::{Debug, Formatter}, + sync::Arc, +}; + use arrow::datatypes::SchemaRef; use blaze_jni_bridge::{jni_call, jni_call_static, jni_new_global_ref, jni_new_string}; -use datafusion::error::{DataFusionError, Result}; -use datafusion::execution::context::TaskContext; -use datafusion::physical_expr::PhysicalSortExpr; -use datafusion::physical_plan::metrics::{ - BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, -}; -use datafusion::physical_plan::Partitioning::UnknownPartitioning; -use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, - Statistics, +use datafusion::{ + error::{DataFusionError, Result}, + execution::context::TaskContext, + physical_expr::PhysicalSortExpr, + physical_plan::{ + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, + Partitioning::UnknownPartitioning, + SendableRecordBatchStream, Statistics, + }, }; use datafusion_ext_commons::streams::ffi_stream::FFIReaderStream; use jni::objects::JObject; -use std::any::Any; -use std::fmt::{Debug, Formatter}; -use std::sync::Arc; pub struct FFIReaderExec { num_partitions: usize, diff --git a/native-engine/datafusion-ext-plans/src/filter_exec.rs b/native-engine/datafusion-ext-plans/src/filter_exec.rs index e3527e83..b3f51c88 100644 --- a/native-engine/datafusion-ext-plans/src/filter_exec.rs +++ b/native-engine/datafusion-ext-plans/src/filter_exec.rs @@ -12,29 +12,32 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::common::batch_statisitcs::{stat_input, InputBatchStatistics}; -use crate::common::cached_exprs_evaluator::CachedExprsEvaluator; -use crate::common::column_pruning::ExecuteWithColumnPruning; -use crate::common::output::output_with_sender; -use crate::project_exec::ProjectExec; +use std::{any::Any, fmt::Formatter, sync::Arc}; + use arrow::datatypes::{DataType, SchemaRef}; -use datafusion::common::Statistics; -use datafusion::common::{DataFusionError, Result}; -use datafusion::execution::context::TaskContext; -use datafusion::physical_expr::expressions::Column; -use datafusion::physical_expr::{PhysicalExprRef, PhysicalSortExpr}; -use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; -use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, +use datafusion::{ + common::{DataFusionError, Result, Statistics}, + execution::context::TaskContext, + physical_expr::{expressions::Column, PhysicalExprRef, PhysicalSortExpr}, + physical_plan::{ + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, + stream::RecordBatchStreamAdapter, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, + }, }; -use datafusion_ext_commons::streams::coalesce_stream::CoalesceStream; -use futures::stream::once; -use futures::{StreamExt, TryStreamExt}; +use datafusion_ext_commons::streams::coalesce_stream::CoalesceInput; +use futures::{stream::once, StreamExt, TryStreamExt}; use itertools::Itertools; -use std::any::Any; -use std::fmt::Formatter; -use std::sync::Arc; + +use crate::{ + common::{ + batch_statisitcs::{stat_input, InputBatchStatistics}, + cached_exprs_evaluator::CachedExprsEvaluator, + column_pruning::ExecuteWithColumnPruning, + output::TaskOutputter, + }, + project_exec::ProjectExec, +}; #[derive(Debug, Clone)] pub struct FilterExec { @@ -121,20 +124,23 @@ impl ExecutionPlan for FilterExec { partition: usize, context: Arc, ) -> Result { - let batch_size = context.session_config().batch_size(); let predicates = self.predicates.clone(); let metrics = BaselineMetrics::new(&self.metrics, partition); - let elapsed_compute = metrics.elapsed_compute().clone(); - let input = stat_input( InputBatchStatistics::from_metrics_set_and_blaze_conf(&self.metrics, partition)?, self.input.execute(partition, context.clone())?, )?; let filtered = Box::pin(RecordBatchStreamAdapter::new( self.schema(), - once(execute_filter(input, context, predicates, metrics)).try_flatten(), + once(execute_filter( + input, + context.clone(), + predicates, + metrics.clone(), + )) + .try_flatten(), )); - let coalesced = Box::pin(CoalesceStream::new(filtered, batch_size, elapsed_compute)); + let coalesced = context.coalesce_with_default_batch_size(filtered, &metrics)?; Ok(coalesced) } @@ -180,18 +186,13 @@ async fn execute_filter( ) -> Result { let cached_exprs_evaluator = CachedExprsEvaluator::try_new(predicates, vec![])?; - output_with_sender( - "Filter", - context, - input.schema(), - move |sender| async move { - while let Some(batch) = input.next().await.transpose()? 
{ - let mut timer = metrics.elapsed_compute().timer(); - let filtered_batch = cached_exprs_evaluator.filter(&batch)?; - metrics.record_output(filtered_batch.num_rows()); - sender.send(Ok(filtered_batch), Some(&mut timer)).await; - } - Ok(()) - }, - ) + context.output_with_sender("Filter", input.schema(), move |sender| async move { + while let Some(batch) = input.next().await.transpose()? { + let mut timer = metrics.elapsed_compute().timer(); + let filtered_batch = cached_exprs_evaluator.filter(&batch)?; + metrics.record_output(filtered_batch.num_rows()); + sender.send(Ok(filtered_batch), Some(&mut timer)).await; + } + Ok(()) + }) } diff --git a/native-engine/datafusion-ext-plans/src/generate/explode.rs b/native-engine/datafusion-ext-plans/src/generate/explode.rs index 7600b975..862c41f0 100644 --- a/native-engine/datafusion-ext-plans/src/generate/explode.rs +++ b/native-engine/datafusion-ext-plans/src/generate/explode.rs @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::generate::{GeneratedRows, Generator}; -use arrow::array::*; -use arrow::record_batch::RecordBatch; -use datafusion::common::Result; -use datafusion::physical_expr::PhysicalExpr; -use itertools::Itertools; use std::sync::Arc; +use arrow::{array::*, record_batch::RecordBatch}; +use datafusion::{common::Result, physical_expr::PhysicalExpr}; +use itertools::Itertools; + +use crate::generate::{GeneratedRows, Generator}; + #[derive(Debug)] pub struct ExplodeArray { child: Arc, diff --git a/native-engine/datafusion-ext-plans/src/generate/mod.rs b/native-engine/datafusion-ext-plans/src/generate/mod.rs index 96077f01..d7771b05 100644 --- a/native-engine/datafusion-ext-plans/src/generate/mod.rs +++ b/native-engine/datafusion-ext-plans/src/generate/mod.rs @@ -14,17 +14,16 @@ pub mod explode; -use crate::generate::explode::{ExplodeArray, ExplodeMap}; +use std::{fmt::Debug, sync::Arc}; -use arrow::datatypes::{DataType, SchemaRef}; +use arrow::{ + array::{ArrayRef, UInt32Array}, + datatypes::{DataType, SchemaRef}, + record_batch::RecordBatch, +}; +use datafusion::{common::Result, error::DataFusionError, physical_plan::PhysicalExpr}; -use arrow::array::{ArrayRef, UInt32Array}; -use arrow::record_batch::RecordBatch; -use datafusion::common::Result; -use datafusion::error::DataFusionError; -use datafusion::physical_plan::PhysicalExpr; -use std::fmt::Debug; -use std::sync::Arc; +use crate::generate::explode::{ExplodeArray, ExplodeMap}; pub trait Generator: Debug + Send + Sync { fn exprs(&self) -> Vec>; diff --git a/native-engine/datafusion-ext-plans/src/generate_exec.rs b/native-engine/datafusion-ext-plans/src/generate_exec.rs index 30b04a03..de899f3d 100644 --- a/native-engine/datafusion-ext-plans/src/generate_exec.rs +++ b/native-engine/datafusion-ext-plans/src/generate_exec.rs @@ -12,29 +12,39 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::common::batch_statisitcs::{stat_input, InputBatchStatistics}; -use crate::common::output::output_with_sender; -use crate::generate::Generator; -use arrow::array::UInt32Builder; -use arrow::datatypes::{Field, Schema, SchemaRef}; -use arrow::error::ArrowError; -use arrow::record_batch::RecordBatch; -use datafusion::common::{Result, Statistics}; -use datafusion::execution::context::TaskContext; -use datafusion::physical_expr::expressions::Column; -use datafusion::physical_expr::{PhysicalExpr, PhysicalSortExpr}; -use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; -use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, +use std::{ + any::Any, + fmt::{Debug, Formatter}, + sync::Arc, }; -use datafusion_ext_commons::streams::coalesce_stream::CoalesceStream; -use futures::stream::once; -use futures::{StreamExt, TryFutureExt, TryStreamExt}; + +use arrow::{ + array::UInt32Builder, + datatypes::{Field, Schema, SchemaRef}, + error::ArrowError, + record_batch::RecordBatch, +}; +use datafusion::{ + common::{Result, Statistics}, + execution::context::TaskContext, + physical_expr::{expressions::Column, PhysicalExpr, PhysicalSortExpr}, + physical_plan::{ + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, + stream::RecordBatchStreamAdapter, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, + }, +}; +use datafusion_ext_commons::streams::coalesce_stream::CoalesceInput; +use futures::{stream::once, StreamExt, TryFutureExt, TryStreamExt}; use num::integer::Roots; -use std::any::Any; -use std::fmt::{Debug, Formatter}; -use std::sync::Arc; + +use crate::{ + common::{ + batch_statisitcs::{stat_input, InputBatchStatistics}, + output::TaskOutputter, + }, + generate::Generator, +}; #[derive(Debug)] pub struct GenerateExec { @@ -143,7 +153,6 @@ impl ExecutionPlan for GenerateExec { partition: usize, context: Arc, ) -> Result { - let batch_size = context.session_config().batch_size(); let output_schema = self.output_schema.clone(); let generator = self.generator.clone(); let outer = self.outer; @@ -158,12 +167,11 @@ impl ExecutionPlan for GenerateExec { once( execute_generate( input, - context, + context.clone(), output_schema, generator, outer, child_output_cols, - batch_size, metrics, ) .map_err(ArrowError::from), @@ -172,11 +180,7 @@ impl ExecutionPlan for GenerateExec { )); let metrics = BaselineMetrics::new(&self.metrics, partition); - let output_coalesced = Box::pin(CoalesceStream::new( - output_stream, - batch_size, - metrics.elapsed_compute().clone(), - )); + let output_coalesced = context.coalesce_with_default_batch_size(output_stream, &metrics)?; Ok(output_coalesced) } @@ -196,12 +200,11 @@ async fn execute_generate( generator: Arc, outer: bool, child_output_cols: Vec, - batch_size: usize, metrics: BaselineMetrics, ) -> Result { - output_with_sender( + let batch_size = context.session_config().batch_size(); + context.output_with_sender( "Generate", - context, output_schema.clone(), move |sender| async move { while let Some(batch) = input_stream @@ -297,19 +300,22 @@ async fn execute_generate( #[cfg(test)] mod test { - use crate::generate::{create_generator, GenerateFunc}; - use crate::generate_exec::GenerateExec; - use arrow::array::*; - use arrow::datatypes::*; - use arrow::record_batch::RecordBatch; - use datafusion::assert_batches_eq; - use datafusion::common::Result; 
- use datafusion::physical_expr::expressions::Column; - use datafusion::physical_plan::memory::MemoryExec; - use datafusion::physical_plan::{common, ExecutionPlan}; - use datafusion::prelude::SessionContext; use std::sync::Arc; + use arrow::{array::*, datatypes::*, record_batch::RecordBatch}; + use datafusion::{ + assert_batches_eq, + common::Result, + physical_expr::expressions::Column, + physical_plan::{common, memory::MemoryExec, ExecutionPlan}, + prelude::SessionContext, + }; + + use crate::{ + generate::{create_generator, GenerateFunc}, + generate_exec::GenerateExec, + }; + #[tokio::test] async fn test_explode() -> Result<()> { let session_ctx = SessionContext::new(); diff --git a/native-engine/datafusion-ext-plans/src/ipc_reader_exec.rs b/native-engine/datafusion-ext-plans/src/ipc_reader_exec.rs index 2af9ed2c..6c43df1b 100644 --- a/native-engine/datafusion-ext-plans/src/ipc_reader_exec.rs +++ b/native-engine/datafusion-ext-plans/src/ipc_reader_exec.rs @@ -12,28 +12,31 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::{ + any::Any, + fmt::{Debug, Formatter}, + sync::Arc, +}; + use arrow::datatypes::SchemaRef; use async_trait::async_trait; use blaze_jni_bridge::{jni_call, jni_call_static, jni_new_global_ref, jni_new_string}; -use datafusion::error::{DataFusionError, Result}; -use datafusion::execution::context::TaskContext; -use datafusion::physical_plan::expressions::PhysicalSortExpr; -use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; -use datafusion::physical_plan::metrics::MetricsSet; -use datafusion::physical_plan::metrics::{BaselineMetrics, MetricBuilder}; -use datafusion::physical_plan::ExecutionPlan; -use datafusion::physical_plan::Partitioning; -use datafusion::physical_plan::Partitioning::UnknownPartitioning; -use datafusion::physical_plan::SendableRecordBatchStream; -use datafusion::physical_plan::Statistics; -use datafusion::physical_plan::{DisplayAs, DisplayFormatType}; -use datafusion_ext_commons::streams::coalesce_stream::CoalesceStream; -use datafusion_ext_commons::streams::ipc_stream::{IpcReadMode, IpcReaderStream}; +use datafusion::{ + error::{DataFusionError, Result}, + execution::context::TaskContext, + physical_plan::{ + expressions::PhysicalSortExpr, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, + Partitioning::UnknownPartitioning, + SendableRecordBatchStream, Statistics, + }, +}; +use datafusion_ext_commons::streams::{ + coalesce_stream::CoalesceInput, + ipc_stream::{IpcReadMode, IpcReaderStream}, +}; use jni::objects::JObject; -use std::any::Any; -use std::fmt::Debug; -use std::fmt::Formatter; -use std::sync::Arc; #[derive(Debug, Clone)] pub struct IpcReaderExec { @@ -126,13 +129,10 @@ impl ExecutionPlan for IpcReaderExec { baseline_metrics, size_counter, )); - Ok(Box::pin(CoalesceStream::new( + Ok(context.coalesce_with_default_batch_size( ipc_stream, - context.session_config().batch_size(), - BaselineMetrics::new(&self.metrics, partition) - .elapsed_compute() - .clone(), - ))) + &BaselineMetrics::new(&self.metrics, partition), + )?) 
} fn metrics(&self) -> Option { diff --git a/native-engine/datafusion-ext-plans/src/ipc_writer_exec.rs b/native-engine/datafusion-ext-plans/src/ipc_writer_exec.rs index 4955e06e..cd669df5 100644 --- a/native-engine/datafusion-ext-plans/src/ipc_writer_exec.rs +++ b/native-engine/datafusion-ext-plans/src/ipc_writer_exec.rs @@ -12,35 +12,29 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::{any::Any, fmt::Formatter, io::Cursor, sync::Arc}; + use arrow::datatypes::SchemaRef; -use arrow::error::ArrowError; -use arrow::record_batch::RecordBatch; use async_trait::async_trait; use blaze_jni_bridge::{ jni_call, jni_call_static, jni_new_direct_byte_buffer, jni_new_global_ref, jni_new_string, }; -use datafusion::error::DataFusionError; -use datafusion::error::Result; -use datafusion::execution::context::TaskContext; -use datafusion::physical_expr::PhysicalSortExpr; -use datafusion::physical_plan::memory::MemoryStream; -use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; -use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, - Statistics, +use datafusion::{ + error::{DataFusionError, Result}, + execution::context::TaskContext, + physical_expr::PhysicalSortExpr, + physical_plan::{ + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, + stream::RecordBatchStreamAdapter, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, + Statistics, + }, }; -use datafusion_ext_commons::concat_batches; -use datafusion_ext_commons::io::write_one_batch; - -use futures::StreamExt; -use futures::TryFutureExt; -use futures::TryStreamExt; +use datafusion_ext_commons::{io::write_one_batch, streams::coalesce_stream::CoalesceInput}; +use futures::{stream::once, StreamExt, TryStreamExt}; use jni::objects::{GlobalRef, JObject}; -use std::any::Any; -use std::fmt::Formatter; -use std::io::Cursor; -use std::sync::Arc; + +use crate::common::output::TaskOutputter; #[derive(Debug)] pub struct IpcWriterExec { @@ -114,18 +108,16 @@ impl ExecutionPlan for IpcWriterExec { )?; let ipc_consumer = jni_new_global_ref!(ipc_consumer_local.as_obj())?; let input = self.input.execute(partition, context.clone())?; + let input_coalesced = context.coalesce_with_default_batch_size(input, &baseline_metrics)?; Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema(), - futures::stream::once( - write_ipc( - input, - context.session_config().batch_size(), - ipc_consumer, - baseline_metrics, - ) - .map_err(|e| ArrowError::ExternalError(Box::new(e))), - ) + once(write_ipc( + input_coalesced, + context, + ipc_consumer, + baseline_metrics, + )) .try_flatten(), ))) } @@ -141,55 +133,26 @@ impl ExecutionPlan for IpcWriterExec { pub async fn write_ipc( mut input: SendableRecordBatchStream, - batch_size: usize, + context: Arc, ipc_consumer: GlobalRef, metrics: BaselineMetrics, ) -> Result { let schema = input.schema(); - let mut batches: Vec = vec![]; - let mut num_rows = 0; - - macro_rules! flush_batches { - () => {{ + context.output_with_sender("IpcWrite", schema.clone(), move |_sender| async move { + while let Some(batch) = input.next().await.transpose()? 
{ let timer = metrics.elapsed_compute().timer(); - let batch = concat_batches(&schema, &batches, num_rows)?; - metrics.record_output(num_rows); - batches.clear(); - num_rows = 0; + let num_rows = batch.num_rows(); let mut buffer = vec![]; - write_one_batch( - &batch, - &mut Cursor::new(&mut buffer), - true, - None, - )?; + write_one_batch(&batch, &mut Cursor::new(&mut buffer), true, None)?; drop(timer); + metrics.record_output(num_rows); let buf = jni_new_direct_byte_buffer!(&buffer)?; let _consumed = jni_call!( ScalaFunction1(ipc_consumer.as_obj()).apply(buf.as_obj()) -> JObject )?; - }} - } - - while let Some(batch) = input.next().await { - let batch = batch?; - - if batch.num_rows() == 0 { - continue; - } - if num_rows + batch.num_rows() > batch_size { - flush_batches!(); } - num_rows += batch.num_rows(); - batches.push(batch); - } - if num_rows > 0 { - flush_batches!(); - } - assert_eq!(num_rows, 0); - - // ipc writer always has empty output - Ok(Box::pin(MemoryStream::try_new(vec![], schema, None)?)) + Ok(()) + }) } diff --git a/native-engine/datafusion-ext-plans/src/lib.rs b/native-engine/datafusion-ext-plans/src/lib.rs index 0d512d54..2e761f58 100644 --- a/native-engine/datafusion-ext-plans/src/lib.rs +++ b/native-engine/datafusion-ext-plans/src/lib.rs @@ -13,7 +13,6 @@ // limitations under the License. #![feature(get_mut_unchecked)] -#![feature(slice_swap_unchecked)] pub mod agg; pub mod agg_exec; @@ -30,6 +29,7 @@ pub mod generate_exec; pub mod ipc_reader_exec; pub mod ipc_writer_exec; pub mod limit_exec; +pub mod memmgr; pub mod parquet_exec; pub mod parquet_sink_exec; pub mod project_exec; diff --git a/native-engine/datafusion-ext-plans/src/limit_exec.rs b/native-engine/datafusion-ext-plans/src/limit_exec.rs index 50e1f8cb..8dc8048f 100644 --- a/native-engine/datafusion-ext-plans/src/limit_exec.rs +++ b/native-engine/datafusion-ext-plans/src/limit_exec.rs @@ -1,19 +1,23 @@ -use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; -use datafusion::common::{DataFusionError, Result}; -use datafusion::execution::context::TaskContext; -use datafusion::physical_expr::PhysicalSortExpr; -use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet}; -use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, - SendableRecordBatchStream, Statistics, +use std::{ + any::Any, + fmt::{Debug, Formatter}, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use datafusion::{ + common::{DataFusionError, Result}, + execution::context::TaskContext, + physical_expr::PhysicalSortExpr, + physical_plan::{ + metrics::{BaselineMetrics, ExecutionPlanMetricsSet}, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, + SendableRecordBatchStream, Statistics, + }, }; use futures::{Stream, StreamExt}; -use std::any::Any; -use std::fmt::{Debug, Formatter}; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; #[derive(Debug)] pub struct LimitExec { @@ -133,18 +137,22 @@ impl Stream for LimitStream { #[cfg(test)] mod test { - use crate::common::memory_manager::MemManager; - use crate::limit_exec::LimitExec; - use arrow::array::Int32Array; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use datafusion::assert_batches_eq; - use datafusion::common::Result; - use datafusion::physical_plan::memory::MemoryExec; - use datafusion::physical_plan::{common, ExecutionPlan}; - 
use datafusion::prelude::SessionContext; use std::sync::Arc; + use arrow::{ + array::Int32Array, + datatypes::{DataType, Field, Schema}, + record_batch::RecordBatch, + }; + use datafusion::{ + assert_batches_eq, + common::Result, + physical_plan::{common, memory::MemoryExec, ExecutionPlan}, + prelude::SessionContext, + }; + + use crate::{limit_exec::LimitExec, memmgr::MemManager}; + fn build_table_i32( a: (&str, &Vec), b: (&str, &Vec), diff --git a/native-engine/datafusion-ext-plans/src/common/memory_manager.rs b/native-engine/datafusion-ext-plans/src/memmgr/mod.rs similarity index 99% rename from native-engine/datafusion-ext-plans/src/common/memory_manager.rs rename to native-engine/datafusion-ext-plans/src/memmgr/mod.rs index 9cb01d0d..654fc1c5 100644 --- a/native-engine/datafusion-ext-plans/src/common/memory_manager.rs +++ b/native-engine/datafusion-ext-plans/src/memmgr/mod.rs @@ -12,13 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. +pub mod onheap_spill; + +use std::{ + sync::{Arc, Weak}, + time::Duration, +}; + use async_trait::async_trait; use bytesize::ByteSize; use datafusion::common::Result; use once_cell::sync::OnceCell; use parking_lot::{Condvar, Mutex}; -use std::sync::{Arc, Weak}; -use std::time::Duration; static MEM_MANAGER: OnceCell> = OnceCell::new(); diff --git a/native-engine/datafusion-ext-plans/src/common/onheap_spill.rs b/native-engine/datafusion-ext-plans/src/memmgr/onheap_spill.rs similarity index 94% rename from native-engine/datafusion-ext-plans/src/common/onheap_spill.rs rename to native-engine/datafusion-ext-plans/src/memmgr/onheap_spill.rs index dbc22528..0fdcf88a 100644 --- a/native-engine/datafusion-ext-plans/src/common/onheap_spill.rs +++ b/native-engine/datafusion-ext-plans/src/memmgr/onheap_spill.rs @@ -12,16 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::{ + fs::File, + io::{BufReader, BufWriter, Read, Seek, Write}, + sync::Arc, +}; + use blaze_jni_bridge::{ is_jni_bridge_inited, jni_call, jni_call_static, jni_new_direct_byte_buffer, jni_new_global_ref, }; -use datafusion::common::Result; -use datafusion::parquet::file::reader::Length; -use jni::objects::GlobalRef; -use jni::sys::{jboolean, jlong, JNI_TRUE}; -use std::fs::File; -use std::io::{BufReader, BufWriter, Read, Seek, Write}; -use std::sync::Arc; +use datafusion::{common::Result, parquet::file::reader::Length}; +use jni::{ + objects::GlobalRef, + sys::{jboolean, jlong, JNI_TRUE}, +}; pub trait Spill: Send + Sync { fn complete(&self) -> Result<()>; diff --git a/native-engine/datafusion-ext-plans/src/parquet_exec.rs b/native-engine/datafusion-ext-plans/src/parquet_exec.rs index 489fea9a..f70a5637 100644 --- a/native-engine/datafusion-ext-plans/src/parquet_exec.rs +++ b/native-engine/datafusion-ext-plans/src/parquet_exec.rs @@ -17,50 +17,49 @@ //! 
Execution plan for reading Parquet files -use fmt::Debug; -use std::any::Any; -use std::fmt; -use std::fmt::Formatter; -use std::ops::Range; -use std::sync::Arc; - -use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, SchemaRef}; -use datafusion::common::DataFusionError; -use datafusion::datasource::physical_plan::parquet::page_filter::PagePruningPredicate; -use datafusion::datasource::physical_plan::parquet::ParquetOpener; -use datafusion::datasource::physical_plan::{ - FileMeta, FileScanConfig, FileStream, OnError, ParquetFileMetrics, ParquetFileReaderFactory, +use std::{any::Any, fmt, fmt::Formatter, ops::Range, sync::Arc}; + +use arrow::{ + array::ArrayRef, + datatypes::{DataType, SchemaRef}, +}; +use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine}; +use blaze_jni_bridge::{ + conf, conf::BooleanConf, jni_call_static, jni_new_global_ref, jni_new_string, }; -use datafusion::parquet::arrow::async_reader::{fetch_parquet_metadata, AsyncFileReader}; -use datafusion::parquet::errors::ParquetError; -use datafusion::parquet::file::metadata::ParquetMetaData; -use datafusion::physical_optimizer::pruning::PruningPredicate; -use datafusion::physical_plan::metrics::{BaselineMetrics, MetricValue, Time}; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; -use datafusion::physical_plan::{DisplayAs, Metric, PhysicalExpr, RecordBatchStream}; +use bytes::Bytes; use datafusion::{ + common::DataFusionError, + datasource::physical_plan::{ + parquet::{page_filter::PagePruningPredicate, ParquetOpener}, + FileMeta, FileScanConfig, FileStream, OnError, ParquetFileMetrics, + ParquetFileReaderFactory, + }, error::Result, execution::context::TaskContext, + parquet::{ + arrow::async_reader::{fetch_parquet_metadata, AsyncFileReader}, + errors::ParquetError, + file::metadata::ParquetMetaData, + }, + physical_optimizer::pruning::PruningPredicate, physical_plan::{ expressions::PhysicalSortExpr, - metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, - DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, + metrics::{ + BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder, MetricValue, MetricsSet, Time, + }, + stream::RecordBatchStreamAdapter, + DisplayAs, DisplayFormatType, ExecutionPlan, Metric, Partitioning, PhysicalExpr, + RecordBatchStream, SendableRecordBatchStream, Statistics, }, }; -use futures::future::BoxFuture; -use futures::stream::once; -use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt}; -use object_store::ObjectMeta; - -use base64::prelude::BASE64_URL_SAFE_NO_PAD; -use base64::Engine; -use blaze_jni_bridge::{jni_call_static, jni_new_global_ref, jni_new_string}; -use bytes::Bytes; use datafusion_ext_commons::hadoop_fs::{FsDataInputStream, FsProvider}; +use fmt::Debug; +use futures::{future::BoxFuture, stream::once, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; +use object_store::ObjectMeta; use once_cell::sync::OnceCell; -use crate::common::output::output_with_sender; +use crate::common::output::TaskOutputter; #[no_mangle] fn schema_adapter_cast_column( @@ -85,7 +84,8 @@ pub struct ParquetExec { } impl ParquetExec { - /// Create a new Parquet reader execution plan provided file list and schema. + /// Create a new Parquet reader execution plan provided file list and + /// schema. 
pub fn new( base_config: FileScanConfig, fs_resource_id: String, @@ -203,7 +203,7 @@ impl ExecutionPlan for ParquetExec { context: Arc, ) -> Result { let baseline_metrics = BaselineMetrics::new(&self.metrics, partition_index); - let timer = baseline_metrics.elapsed_compute().timer(); + let _timer = baseline_metrics.elapsed_compute().timer(); let io_time = Time::default(); let io_time_metric = Arc::new(Metric::new( @@ -241,23 +241,20 @@ impl ExecutionPlan for ParquetExec { reorder_filters: false, enable_page_index: false, }; - drop(timer); let baseline_metrics = BaselineMetrics::new(&self.metrics, partition_index); let elapsed_compute = baseline_metrics.elapsed_compute().clone(); let mut file_stream = FileStream::new(&self.base_config, partition_index, opener, &self.metrics)?; - if jni_call_static!(BlazeConf.ignoreCorruptedFiles() -> bool)? { + if conf::IGNORE_CORRUPTED_FILES.value()? { file_stream = file_stream.with_on_error(OnError::Skip); } let mut stream = Box::pin(file_stream); - Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema(), once(async move { - output_with_sender( + context.output_with_sender( "ParquetScan", - context, stream.schema(), move |sender| async move { let mut timer = elapsed_compute.timer(); diff --git a/native-engine/datafusion-ext-plans/src/parquet_sink_exec.rs b/native-engine/datafusion-ext-plans/src/parquet_sink_exec.rs index 7f251cef..b90e309d 100644 --- a/native-engine/datafusion-ext-plans/src/parquet_sink_exec.rs +++ b/native-engine/datafusion-ext-plans/src/parquet_sink_exec.rs @@ -15,35 +15,34 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; +use std::{any::Any, fmt::Formatter, io::Write, sync::Arc}; + +use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; use blaze_jni_bridge::{jni_call_static, jni_new_global_ref, jni_new_string}; -use datafusion::common::{DataFusionError, Result, Statistics}; -use datafusion::execution::context::TaskContext; -use datafusion::parquet::arrow::{parquet_to_arrow_schema, ArrowWriter}; -use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel}; -use datafusion::parquet::file::properties::{WriterProperties, WriterVersion}; -use datafusion::parquet::schema::parser::parse_message_type; -use datafusion::parquet::schema::types::SchemaDescriptor; -use datafusion::physical_expr::PhysicalSortExpr; -use datafusion::physical_plan::metrics::{ - BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricValue, MetricsSet, Time, +use datafusion::{ + common::{DataFusionError, Result, Statistics}, + execution::context::TaskContext, + parquet::{ + arrow::{parquet_to_arrow_schema, ArrowWriter}, + basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel}, + file::properties::{WriterProperties, WriterVersion}, + schema::{parser::parse_message_type, types::SchemaDescriptor}, + }, + physical_expr::PhysicalSortExpr, + physical_plan::{ + metrics::{BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricValue, MetricsSet, Time}, + stream::RecordBatchStreamAdapter, + DisplayAs, DisplayFormatType, EmptyRecordBatchStream, ExecutionPlan, Metric, Partitioning, + SendableRecordBatchStream, + }, }; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; -use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, EmptyRecordBatchStream, ExecutionPlan, Metric, Partitioning, - SendableRecordBatchStream, +use datafusion_ext_commons::{ + cast::cast, + hadoop_fs::{FsDataOutputStream, FsProvider}, }; -use 
datafusion_ext_commons::cast::cast; -use datafusion_ext_commons::hadoop_fs::{FsDataOutputStream, FsProvider}; -use futures::stream::once; -use futures::{StreamExt, TryStreamExt}; +use futures::{stream::once, StreamExt, TryStreamExt}; use once_cell::sync::OnceCell; use parking_lot::Mutex; -use std::any::Any; -use std::fmt::Formatter; -use std::io::Write; -use std::sync::Arc; #[derive(Debug)] pub struct ParquetSinkExec { diff --git a/native-engine/datafusion-ext-plans/src/project_exec.rs b/native-engine/datafusion-ext-plans/src/project_exec.rs index 72788032..a0d76452 100644 --- a/native-engine/datafusion-ext-plans/src/project_exec.rs +++ b/native-engine/datafusion-ext-plans/src/project_exec.rs @@ -15,28 +15,32 @@ // specific language governing permissions and limitations // under the License. -use crate::common::batch_statisitcs::{stat_input, InputBatchStatistics}; -use crate::common::cached_exprs_evaluator::CachedExprsEvaluator; -use crate::common::output::output_with_sender; -use crate::filter_exec::FilterExec; +use std::{any::Any, fmt::Formatter, sync::Arc}; + use arrow::datatypes::{Field, Fields, Schema, SchemaRef}; -use datafusion::common::{Result, Statistics}; -use datafusion::execution::TaskContext; -use datafusion::physical_expr::{PhysicalExprRef, PhysicalSortExpr}; -use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; -use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, +use datafusion::{ + common::{Result, Statistics}, + execution::TaskContext, + physical_expr::{PhysicalExprRef, PhysicalSortExpr}, + physical_plan::{ + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, + stream::RecordBatchStreamAdapter, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, + }, }; -use datafusion_ext_commons::streams::coalesce_stream::CoalesceStream; -use futures::stream::once; -use futures::{FutureExt, StreamExt, TryStreamExt}; +use datafusion_ext_commons::streams::coalesce_stream::CoalesceInput; +use futures::{stream::once, FutureExt, StreamExt, TryStreamExt}; use itertools::Itertools; -use std::any::Any; -use std::fmt::Formatter; -use std::sync::Arc; -use crate::common::column_pruning::{prune_columns, ExecuteWithColumnPruning}; +use crate::{ + common::{ + batch_statisitcs::{stat_input, InputBatchStatistics}, + cached_exprs_evaluator::CachedExprsEvaluator, + column_pruning::{prune_columns, ExecuteWithColumnPruning}, + output::TaskOutputter, + }, + filter_exec::FilterExec, +}; #[derive(Debug, Clone)] pub struct ProjectExec { @@ -122,17 +126,14 @@ impl ExecutionPlan for ProjectExec { partition: usize, context: Arc, ) -> Result { - let batch_size = context.session_config().batch_size(); let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); - let elapsed_compute = baseline_metrics.elapsed_compute().clone(); - let exprs: Vec = self.expr.iter().map(|(e, _name)| e.clone()).collect(); let fut = if let Some(filter_exec) = self.input.as_any().downcast_ref::() { execute_project_with_filtering( filter_exec.children()[0].clone(), partition, - context, + context.clone(), self.schema(), filter_exec.predicates().to_vec(), exprs, @@ -143,7 +144,7 @@ impl ExecutionPlan for ProjectExec { execute_project_with_filtering( self.input.clone(), partition, - context, + context.clone(), self.schema(), vec![], exprs, @@ -156,9 +157,7 @@ impl ExecutionPlan for ProjectExec { 
self.schema(), once(fut).try_flatten(), )); - - let coalesced = Box::pin(CoalesceStream::new(output, batch_size, elapsed_compute)); - Ok(coalesced) + Ok(context.coalesce_with_default_batch_size(output, &baseline_metrics)?) } fn metrics(&self) -> Option { @@ -217,19 +216,14 @@ async fn execute_project_with_filtering( input.execute_projected(partition, context.clone(), &projection)?, )?; - output_with_sender( - "Project", - context, - output_schema.clone(), - move |sender| async move { - while let Some(batch) = input.next().await.transpose()? { - let mut timer = baseline_metrics.elapsed_compute().timer(); - let output_batch = - cached_expr_evaluator.filter_project(&batch, output_schema.clone())?; - baseline_metrics.record_output(output_batch.num_rows()); - sender.send(Ok(output_batch), Some(&mut timer)).await; - } - Ok(()) - }, - ) + context.output_with_sender("Project", output_schema.clone(), move |sender| async move { + while let Some(batch) = input.next().await.transpose()? { + let mut timer = baseline_metrics.elapsed_compute().timer(); + let output_batch = + cached_expr_evaluator.filter_project(&batch, output_schema.clone())?; + baseline_metrics.record_output(output_batch.num_rows()); + sender.send(Ok(output_batch), Some(&mut timer)).await; + } + Ok(()) + }) } diff --git a/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs b/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs index 4297094e..e556351c 100644 --- a/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs +++ b/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs @@ -12,27 +12,33 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::any::Any; -use std::fmt::Formatter; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; +use std::{ + any::Any, + fmt::Formatter, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; -use crate::agg::AGG_BUF_COLUMN_NAME; -use arrow::datatypes::{Field, Fields, Schema, SchemaRef}; -use arrow::record_batch::RecordBatch; +use arrow::{ + datatypes::{Field, Fields, Schema, SchemaRef}, + record_batch::RecordBatch, +}; use async_trait::async_trait; -use datafusion::error::DataFusionError; -use datafusion::error::Result; -use datafusion::execution::context::TaskContext; -use datafusion::physical_expr::PhysicalSortExpr; -use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; -use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, - SendableRecordBatchStream, Statistics, +use datafusion::{ + error::{DataFusionError, Result}, + execution::context::TaskContext, + physical_expr::PhysicalSortExpr, + physical_plan::{ + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, + SendableRecordBatchStream, Statistics, + }, }; use futures::{Stream, StreamExt}; +use crate::agg::AGG_BUF_COLUMN_NAME; + #[derive(Debug, Clone)] pub struct RenameColumnsExec { input: Arc, diff --git a/native-engine/datafusion-ext-plans/src/rss_shuffle_writer_exec.rs b/native-engine/datafusion-ext-plans/src/rss_shuffle_writer_exec.rs index 5e759de7..0a6c7561 100644 --- a/native-engine/datafusion-ext-plans/src/rss_shuffle_writer_exec.rs +++ b/native-engine/datafusion-ext-plans/src/rss_shuffle_writer_exec.rs @@ -14,37 +14,36 @@ //! 
Defines the External shuffle repartition plan -use std::any::Any; -use std::fmt::Debug; -use std::sync::Arc; +use std::{any::Any, fmt::Debug, sync::Arc}; use async_trait::async_trait; -use datafusion::arrow::datatypes::SchemaRef; -use datafusion::arrow::error::ArrowError; - -use datafusion::error::{DataFusionError, Result}; -use datafusion::execution::context::TaskContext; - -use crate::common::memory_manager::MemManager; -use crate::shuffle::rss_bucket_repartitioner::RssBucketShuffleRepartitioner; -use crate::shuffle::rss_single_repartitioner::RssSingleShuffleRepartitioner; -use crate::shuffle::rss_sort_repartitioner::RssSortShuffleRepartitioner; -use crate::shuffle::{can_use_bucket_repartitioner, ShuffleRepartitioner}; use blaze_jni_bridge::{jni_call_static, jni_new_global_ref, jni_new_string}; -use datafusion::physical_plan::expressions::PhysicalSortExpr; -use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet}; -use datafusion::physical_plan::metrics::{MetricBuilder, MetricsSet}; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; -use datafusion::physical_plan::ExecutionPlan; -use datafusion::physical_plan::Partitioning; -use datafusion::physical_plan::SendableRecordBatchStream; -use datafusion::physical_plan::Statistics; -use datafusion::physical_plan::{DisplayAs, DisplayFormatType}; -use futures::stream::once; -use futures::{TryFutureExt, TryStreamExt}; - -/// The rss shuffle writer operator maps each input partition to M output partitions based on a -/// partitioning scheme. No guarantees are made about the order of the resulting partitions. +use datafusion::{ + arrow::datatypes::SchemaRef, + error::{DataFusionError, Result}, + execution::context::TaskContext, + physical_plan::{ + expressions::PhysicalSortExpr, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, + stream::RecordBatchStreamAdapter, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, + Statistics, + }, +}; +use futures::{stream::once, TryStreamExt}; + +use crate::{ + memmgr::MemManager, + shuffle::{ + can_use_bucket_repartitioner, rss_bucket_repartitioner::RssBucketShuffleRepartitioner, + rss_single_repartitioner::RssSingleShuffleRepartitioner, + rss_sort_repartitioner::RssSortShuffleRepartitioner, ShuffleRepartitioner, + }, +}; + +/// The rss shuffle writer operator maps each input partition to M output +/// partitions based on a partitioning scheme. No guarantees are made about the +/// order of the resulting partitions. #[derive(Debug)] pub struct RssShuffleWriterExec { /// Input execution plan @@ -127,7 +126,7 @@ impl ExecutionPlan for RssShuffleWriterExec { rss_partition_writer, data_size_metric, )), - p @ Partitioning::Hash(_, _) + p @ Partitioning::Hash(..) if can_use_bucket_repartitioner(&self.input.schema()) && p.partition_count() < 200 => { @@ -142,7 +141,7 @@ impl ExecutionPlan for RssShuffleWriterExec { MemManager::register_consumer(partitioner.clone(), true); partitioner } - Partitioning::Hash(_, _) => { + Partitioning::Hash(..) 
=> { let partitioner = Arc::new(RssSortShuffleRepartitioner::new( partition, rss_partition_writer, @@ -156,19 +155,14 @@ impl ExecutionPlan for RssShuffleWriterExec { } p => unreachable!("unsupported partitioning: {:?}", p), }; - - let stream = repartitioner - .execute( + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + once(repartitioner.execute( context.clone(), input, - context.session_config().batch_size(), BaselineMetrics::new(&self.metrics, partition), - ) - .map_err(|e| ArrowError::ExternalError(Box::new(e))); - - Ok(Box::pin(RecordBatchStreamAdapter::new( - self.schema(), - once(stream).try_flatten(), + )) + .try_flatten(), ))) } diff --git a/native-engine/datafusion-ext-plans/src/shuffle/bucket_repartitioner.rs b/native-engine/datafusion-ext-plans/src/shuffle/bucket_repartitioner.rs index 356c0f6a..3834ecca 100644 --- a/native-engine/datafusion-ext-plans/src/shuffle/bucket_repartitioner.rs +++ b/native-engine/datafusion-ext-plans/src/shuffle/bucket_repartitioner.rs @@ -14,26 +14,37 @@ //! Defines the sort-based shuffle writer -use crate::common::memory_manager::{MemConsumer, MemConsumerInfo, MemManager}; -use crate::common::onheap_spill::{try_new_spill, Spill}; -use crate::shuffle::{evaluate_hashes, evaluate_partition_ids, ShuffleRepartitioner, ShuffleSpill}; -use arrow::array::*; -use arrow::datatypes::*; -use arrow::error::Result as ArrowResult; -use arrow::record_batch::RecordBatch; +use std::{ + fs::{File, OpenOptions}, + io::{Cursor, Read, Seek, SeekFrom, Write}, + sync::{Arc, Weak}, +}; + +use arrow::{array::*, datatypes::*, error::Result as ArrowResult, record_batch::RecordBatch}; use async_trait::async_trait; -use datafusion::common::{DataFusionError, Result}; -use datafusion::execution::context::TaskContext; -use datafusion::physical_plan::metrics::{BaselineMetrics, Count}; -use datafusion::physical_plan::Partitioning; -use datafusion_ext_commons::array_builder::{builder_extend, make_batch, new_array_builders}; -use datafusion_ext_commons::concat_batches; -use datafusion_ext_commons::io::write_one_batch; +use datafusion::{ + common::{DataFusionError, Result}, + execution::context::TaskContext, + physical_plan::{ + metrics::{BaselineMetrics, Count}, + Partitioning, + }, +}; +use datafusion_ext_commons::{ + array_builder::{builder_extend, make_batch, new_array_builders}, + concat_batches, + io::write_one_batch, +}; use futures::lock::Mutex; use itertools::Itertools; -use std::fs::{File, OpenOptions}; -use std::io::{Cursor, Read, Seek, SeekFrom, Write}; -use std::sync::{Arc, Weak}; + +use crate::{ + memmgr::{ + onheap_spill::{try_new_spill, Spill}, + MemConsumer, MemConsumerInfo, MemManager, + }, + shuffle::{evaluate_hashes, evaluate_partition_ids, ShuffleRepartitioner, ShuffleSpill}, +}; pub struct BucketShuffleRepartitioner { name: String, @@ -475,14 +486,14 @@ fn slot_size(len: usize, data_type: &DataType) -> usize { DataType::Float64 => len * 8, DataType::Date32 => len * 4, DataType::Date64 => len * 8, - DataType::Timestamp(_, _) => len * 8, + DataType::Timestamp(..) => len * 8, DataType::Time32(_) => len * 4, DataType::Time64(_) => len * 8, DataType::Binary => len * 4, DataType::LargeBinary => len * 8, DataType::Utf8 => len * 4, DataType::LargeUtf8 => len * 8, - DataType::Decimal128(_, _) => len * 16, + DataType::Decimal128(..) 
=> len * 16, DataType::Dictionary(key_type, _) => match key_type.as_ref() { DataType::Int8 => len, DataType::Int16 => len * 2, diff --git a/native-engine/datafusion-ext-plans/src/shuffle/mod.rs b/native-engine/datafusion-ext-plans/src/shuffle/mod.rs index c80d88c4..4466159b 100644 --- a/native-engine/datafusion-ext-plans/src/shuffle/mod.rs +++ b/native-engine/datafusion-ext-plans/src/shuffle/mod.rs @@ -12,22 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::common::onheap_spill::Spill; -use crate::common::output::output_with_sender; -use arrow::datatypes::SchemaRef; -use arrow::error::Result as ArrowResult; -use arrow::record_batch::RecordBatch; +use std::sync::Arc; + +use arrow::{datatypes::SchemaRef, error::Result as ArrowResult, record_batch::RecordBatch}; use async_trait::async_trait; -use datafusion::common::Result; -use datafusion::error::DataFusionError; -use datafusion::execution::context::TaskContext; -use datafusion::physical_plan::metrics::BaselineMetrics; -use datafusion::physical_plan::{Partitioning, SendableRecordBatchStream}; -use datafusion_ext_commons::array_builder::has_array_builder_supported; -use datafusion_ext_commons::spark_hash::{create_hashes, pmod}; -use datafusion_ext_commons::streams::coalesce_stream::CoalesceStream; +use datafusion::{ + common::Result, + error::DataFusionError, + execution::context::TaskContext, + physical_plan::{metrics::BaselineMetrics, Partitioning, SendableRecordBatchStream}, +}; +use datafusion_ext_commons::{ + array_builder::has_array_builder_supported, + spark_hash::{create_hashes, pmod}, + streams::coalesce_stream::CoalesceInput, +}; use futures::StreamExt; -use std::sync::Arc; + +use crate::{common::output::TaskOutputter, memmgr::onheap_spill::Spill}; pub mod bucket_repartitioner; pub mod single_repartitioner; @@ -56,20 +58,15 @@ impl dyn ShuffleRepartitioner { self: Arc, context: Arc, input: SendableRecordBatchStream, - batch_size: usize, metrics: BaselineMetrics, ) -> Result { let input_schema = input.schema(); // coalesce input - let mut coalesced = Box::pin(CoalesceStream::new( - input, - batch_size, - metrics.elapsed_compute().clone(), - )); + let mut coalesced = context.coalesce_with_default_batch_size(input, &metrics)?; // process all input batches - output_with_sender("Shuffle", context, input_schema, |_| async move { + context.output_with_sender("Shuffle", input_schema, |_| async move { while let Some(batch) = coalesced.next().await.transpose()? { let _timer = metrics.elapsed_compute().timer(); metrics.record_output(batch.num_rows()); diff --git a/native-engine/datafusion-ext-plans/src/shuffle/rss.rs b/native-engine/datafusion-ext-plans/src/shuffle/rss.rs index 08b821f6..64648e59 100644 --- a/native-engine/datafusion-ext-plans/src/shuffle/rss.rs +++ b/native-engine/datafusion-ext-plans/src/shuffle/rss.rs @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
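Note on the shuffle hunks above: each row is routed to an output partition by hashing the partition expressions (create_hashes) and reducing the hash with pmod, Spark's non-negative modulo, before the per-partition buffers are filled. A minimal sketch of that mapping, assuming a 32-bit signed hash per row as in Spark; the helper below is illustrative only and not the crate's actual evaluate_partition_ids:

    // Illustrative only: map per-row hashes to output partition ids the way
    // Spark's Pmod does, so a negative hash never yields a negative partition.
    fn pmod(hash: i32, num_partitions: usize) -> usize {
        let n = num_partitions as i32;
        (((hash % n) + n) % n) as usize
    }

    fn evaluate_partition_ids(hashes: &[i32], num_partitions: usize) -> Vec<usize> {
        hashes.iter().map(|&h| pmod(h, num_partitions)).collect()
    }

    fn main() {
        // a negative murmur3 hash still lands in a valid bucket
        assert_eq!(pmod(-3, 200), 197);
        assert_eq!(evaluate_partition_ids(&[1, -3, 401], 200), vec![1, 197, 1]);
    }
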
+use std::io::Cursor; + use arrow::record_batch::RecordBatch; use blaze_jni_bridge::{jni_call, jni_new_direct_byte_buffer}; use datafusion::common::Result; use datafusion_ext_commons::io::write_one_batch; use jni::objects::GlobalRef; -use std::io::Cursor; pub fn rss_write_batch( rss_partition_writer: &GlobalRef, diff --git a/native-engine/datafusion-ext-plans/src/shuffle/rss_bucket_repartitioner.rs b/native-engine/datafusion-ext-plans/src/shuffle/rss_bucket_repartitioner.rs index 00d99dcb..99a1ee0f 100644 --- a/native-engine/datafusion-ext-plans/src/shuffle/rss_bucket_repartitioner.rs +++ b/native-engine/datafusion-ext-plans/src/shuffle/rss_bucket_repartitioner.rs @@ -14,23 +14,28 @@ //! Defines the rss bucket shuffle repartitioner -use crate::common::memory_manager::{MemConsumer, MemConsumerInfo, MemManager}; -use crate::shuffle::rss::{rss_flush, rss_write_batch}; -use crate::shuffle::{evaluate_hashes, evaluate_partition_ids, ShuffleRepartitioner}; +use std::sync::{Arc, Weak}; + use async_trait::async_trait; -use datafusion::arrow::array::*; -use datafusion::arrow::datatypes::*; -use datafusion::arrow::error::Result as ArrowResult; -use datafusion::arrow::record_batch::RecordBatch; -use datafusion::common::Result; -use datafusion::execution::context::TaskContext; -use datafusion::physical_plan::metrics::Count; -use datafusion::physical_plan::Partitioning; +use datafusion::{ + arrow::{array::*, datatypes::*, error::Result as ArrowResult, record_batch::RecordBatch}, + common::Result, + execution::context::TaskContext, + physical_plan::{metrics::Count, Partitioning}, +}; use datafusion_ext_commons::array_builder::{builder_extend, make_batch, new_array_builders}; use futures::lock::Mutex; use itertools::Itertools; use jni::objects::GlobalRef; -use std::sync::{Arc, Weak}; + +use crate::{ + memmgr::{MemConsumer, MemConsumerInfo, MemManager}, + shuffle::{ + evaluate_hashes, evaluate_partition_ids, + rss::{rss_flush, rss_write_batch}, + ShuffleRepartitioner, + }, +}; pub struct RssBucketShuffleRepartitioner { name: String, diff --git a/native-engine/datafusion-ext-plans/src/shuffle/rss_single_repartitioner.rs b/native-engine/datafusion-ext-plans/src/shuffle/rss_single_repartitioner.rs index 56fd77b6..d354e872 100644 --- a/native-engine/datafusion-ext-plans/src/shuffle/rss_single_repartitioner.rs +++ b/native-engine/datafusion-ext-plans/src/shuffle/rss_single_repartitioner.rs @@ -12,15 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. 
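The RSS repartitioners in this area serialize each sub-batch into an in-memory buffer (write_one_batch into a Cursor) before handing the bytes to the JVM-side partition writer through a direct byte buffer. A rough sketch of that first step using plain arrow-rs IPC; the real code uses datafusion_ext_commons::io::write_one_batch, whose framing may differ:

    use std::io::Cursor;

    use arrow::{ipc::writer::StreamWriter, record_batch::RecordBatch};

    // Serialize one batch into an owned Vec<u8>; the bytes could then be
    // wrapped in a direct byte buffer and pushed to the rss partition writer
    // over JNI.
    fn batch_to_bytes(batch: &RecordBatch) -> arrow::error::Result<Vec<u8>> {
        let mut cursor = Cursor::new(Vec::new());
        let mut writer = StreamWriter::try_new(&mut cursor, &batch.schema())?;
        writer.write(batch)?;
        writer.finish()?;
        drop(writer);
        Ok(cursor.into_inner())
    }
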
-use crate::shuffle::ShuffleRepartitioner; +use std::io::Cursor; + use async_trait::async_trait; use blaze_jni_bridge::{jni_call, jni_new_direct_byte_buffer}; -use datafusion::arrow::record_batch::RecordBatch; -use datafusion::common::Result; -use datafusion::physical_plan::metrics::Count; +use datafusion::{arrow::record_batch::RecordBatch, common::Result, physical_plan::metrics::Count}; use datafusion_ext_commons::io::write_one_batch; use jni::objects::GlobalRef; -use std::io::Cursor; + +use crate::shuffle::ShuffleRepartitioner; pub struct RssSingleShuffleRepartitioner { rss_partition_writer: GlobalRef, diff --git a/native-engine/datafusion-ext-plans/src/shuffle/rss_sort_repartitioner.rs b/native-engine/datafusion-ext-plans/src/shuffle/rss_sort_repartitioner.rs index 55ec95d2..e7149f8e 100644 --- a/native-engine/datafusion-ext-plans/src/shuffle/rss_sort_repartitioner.rs +++ b/native-engine/datafusion-ext-plans/src/shuffle/rss_sort_repartitioner.rs @@ -12,22 +12,31 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::common::memory_manager::{MemConsumer, MemConsumerInfo, MemManager}; -use crate::common::BatchesInterleaver; -use crate::shuffle::rss::{rss_flush, rss_write_batch}; -use crate::shuffle::sort_repartitioner::PI; -use crate::shuffle::{evaluate_hashes, evaluate_partition_ids, ShuffleRepartitioner}; -use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; +use std::{ + mem::size_of, + sync::{Arc, Weak}, +}; + +use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; use async_trait::async_trait; -use datafusion::common::Result; -use datafusion::execution::context::TaskContext; -use datafusion::physical_plan::metrics::Count; -use datafusion::physical_plan::Partitioning; +use datafusion::{ + common::Result, + execution::context::TaskContext, + physical_plan::{metrics::Count, Partitioning}, +}; use futures::lock::Mutex; use jni::objects::GlobalRef; -use std::mem::size_of; -use std::sync::{Arc, Weak}; + +use crate::{ + common::BatchesInterleaver, + memmgr::{MemConsumer, MemConsumerInfo, MemManager}, + shuffle::{ + evaluate_hashes, evaluate_partition_ids, + rss::{rss_flush, rss_write_batch}, + sort_repartitioner::PI, + ShuffleRepartitioner, + }, +}; pub struct RssSortShuffleRepartitioner { name: String, diff --git a/native-engine/datafusion-ext-plans/src/shuffle/single_repartitioner.rs b/native-engine/datafusion-ext-plans/src/shuffle/single_repartitioner.rs index d0d82265..e87fdd31 100644 --- a/native-engine/datafusion-ext-plans/src/shuffle/single_repartitioner.rs +++ b/native-engine/datafusion-ext-plans/src/shuffle/single_repartitioner.rs @@ -12,16 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::shuffle::ShuffleRepartitioner; +use std::{ + fs::{File, OpenOptions}, + io::{Seek, Write}, +}; + use arrow::record_batch::RecordBatch; use async_trait::async_trait; -use datafusion::common::Result; -use datafusion::error::DataFusionError; -use datafusion::physical_plan::metrics::{BaselineMetrics, Count}; +use datafusion::{ + common::Result, + error::DataFusionError, + physical_plan::metrics::{BaselineMetrics, Count}, +}; use datafusion_ext_commons::io::write_one_batch; use once_cell::sync::OnceCell; -use std::fs::{File, OpenOptions}; -use std::io::{Seek, Write}; + +use crate::shuffle::ShuffleRepartitioner; pub struct SingleShuffleRepartitioner { output_data_file: String, diff --git a/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs b/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs index d0705850..5356c89e 100644 --- a/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs +++ b/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs @@ -12,24 +12,34 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::common::memory_manager::{MemConsumer, MemConsumerInfo, MemManager}; -use crate::common::onheap_spill::{try_new_spill, Spill}; -use crate::common::BatchesInterleaver; -use crate::shuffle::{evaluate_hashes, evaluate_partition_ids, ShuffleRepartitioner, ShuffleSpill}; -use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; +use std::{ + fs::{File, OpenOptions}, + io::{BufReader, Cursor, Read, Seek, Write}, + sync::{Arc, Weak}, +}; + +use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; use async_trait::async_trait; -use datafusion::common::{DataFusionError, Result}; -use datafusion::execution::context::TaskContext; -use datafusion::physical_plan::metrics::{BaselineMetrics, Count}; -use datafusion::physical_plan::Partitioning; -use datafusion_ext_commons::io::write_one_batch; -use datafusion_ext_commons::loser_tree::LoserTree; +use datafusion::{ + common::{DataFusionError, Result}, + execution::context::TaskContext, + physical_plan::{ + metrics::{BaselineMetrics, Count}, + Partitioning, + }, +}; +use datafusion_ext_commons::{io::write_one_batch, loser_tree::LoserTree}; use derivative::Derivative; use futures::lock::Mutex; -use std::fs::{File, OpenOptions}; -use std::io::{BufReader, Cursor, Read, Seek, Write}; -use std::sync::{Arc, Weak}; + +use crate::{ + common::BatchesInterleaver, + memmgr::{ + onheap_spill::{try_new_spill, Spill}, + MemConsumer, MemConsumerInfo, MemManager, + }, + shuffle::{evaluate_hashes, evaluate_partition_ids, ShuffleRepartitioner, ShuffleSpill}, +}; pub struct SortShuffleRepartitioner { name: String, diff --git a/native-engine/datafusion-ext-plans/src/shuffle_writer_exec.rs b/native-engine/datafusion-ext-plans/src/shuffle_writer_exec.rs index 532ad377..705e04f3 100644 --- a/native-engine/datafusion-ext-plans/src/shuffle_writer_exec.rs +++ b/native-engine/datafusion-ext-plans/src/shuffle_writer_exec.rs @@ -14,35 +14,36 @@ //! 
Defines the External shuffle repartition plan -use std::any::Any; -use std::fmt::Debug; -use std::sync::Arc; - -use crate::common::batch_statisitcs::{stat_input, InputBatchStatistics}; -use crate::common::memory_manager::MemManager; -use crate::shuffle::bucket_repartitioner::BucketShuffleRepartitioner; -use crate::shuffle::single_repartitioner::SingleShuffleRepartitioner; -use crate::shuffle::sort_repartitioner::SortShuffleRepartitioner; -use crate::shuffle::{can_use_bucket_repartitioner, ShuffleRepartitioner}; +use std::{any::Any, fmt::Debug, sync::Arc}; + use arrow::datatypes::SchemaRef; -use arrow::error::ArrowError; use async_trait::async_trait; -use datafusion::error::{DataFusionError, Result}; -use datafusion::execution::context::TaskContext; -use datafusion::physical_plan::expressions::PhysicalSortExpr; -use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet}; -use datafusion::physical_plan::metrics::{MetricBuilder, MetricsSet}; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; -use datafusion::physical_plan::ExecutionPlan; -use datafusion::physical_plan::Partitioning; -use datafusion::physical_plan::SendableRecordBatchStream; -use datafusion::physical_plan::Statistics; -use datafusion::physical_plan::{DisplayAs, DisplayFormatType}; -use futures::stream::once; -use futures::{TryFutureExt, TryStreamExt}; - -/// The shuffle writer operator maps each input partition to M output partitions based on a -/// partitioning scheme. No guarantees are made about the order of the resulting partitions. +use datafusion::{ + error::{DataFusionError, Result}, + execution::context::TaskContext, + physical_plan::{ + expressions::PhysicalSortExpr, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, + stream::RecordBatchStreamAdapter, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, + Statistics, + }, +}; +use futures::{stream::once, TryStreamExt}; + +use crate::{ + common::batch_statisitcs::{stat_input, InputBatchStatistics}, + memmgr::MemManager, + shuffle::{ + bucket_repartitioner::BucketShuffleRepartitioner, can_use_bucket_repartitioner, + single_repartitioner::SingleShuffleRepartitioner, + sort_repartitioner::SortShuffleRepartitioner, ShuffleRepartitioner, + }, +}; + +/// The shuffle writer operator maps each input partition to M output partitions +/// based on a partitioning scheme. No guarantees are made about the order of +/// the resulting partitions. #[derive(Debug)] pub struct ShuffleWriterExec { /// Input execution plan @@ -119,7 +120,7 @@ impl ExecutionPlan for ShuffleWriterExec { BaselineMetrics::new(&self.metrics, partition), data_size_metric, )), - p @ Partitioning::Hash(_, _) + p @ Partitioning::Hash(..) if can_use_bucket_repartitioner(&self.input.schema()) && p.partition_count() < 200 => { @@ -136,7 +137,7 @@ impl ExecutionPlan for ShuffleWriterExec { MemManager::register_consumer(partitioner.clone(), true); partitioner } - Partitioning::Hash(_, _) => { + Partitioning::Hash(..) 
=> { let partitioner = Arc::new(SortShuffleRepartitioner::new( partition, self.output_data_file.clone(), @@ -157,18 +158,14 @@ impl ExecutionPlan for ShuffleWriterExec { InputBatchStatistics::from_metrics_set_and_blaze_conf(&self.metrics, partition)?, self.input.execute(partition, context.clone())?, )?; - let stream = repartitioner - .execute( + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + once(repartitioner.execute( context.clone(), input, - context.session_config().batch_size(), BaselineMetrics::new(&self.metrics, partition), - ) - .map_err(|e| ArrowError::ExternalError(Box::new(e))); - - Ok(Box::pin(RecordBatchStreamAdapter::new( - self.schema(), - once(stream).try_flatten(), + )) + .try_flatten(), ))) } diff --git a/native-engine/datafusion-ext-plans/src/sort_exec.rs b/native-engine/datafusion-ext-plans/src/sort_exec.rs index dce15729..07efd585 100644 --- a/native-engine/datafusion-ext-plans/src/sort_exec.rs +++ b/native-engine/datafusion-ext-plans/src/sort_exec.rs @@ -14,51 +14,62 @@ //! Defines the External shuffle repartition plan -use crate::common::batch_statisitcs::{stat_input, InputBatchStatistics}; -use crate::common::bytes_arena::BytesArena; -use crate::common::column_pruning::ExecuteWithColumnPruning; -use crate::common::memory_manager::{MemConsumer, MemConsumerInfo, MemManager}; -use crate::common::onheap_spill::{try_new_spill, Spill}; -use crate::common::output::{ - output_bufferable_with_spill, output_with_sender, WrappedRecordBatchSender, +use std::{ + any::Any, + collections::VecDeque, + fmt::Formatter, + io::{BufReader, Cursor, Read, Write}, + mem::size_of, + sync::{Arc, Weak}, +}; + +use arrow::{ + array::ArrayRef, + datatypes::SchemaRef, + record_batch::RecordBatch, + row::{RowConverter, SortField}, }; -use crate::common::slim_bytes::SlimBytes; -use crate::common::{BatchTaker, BatchesInterleaver}; -use arrow::array::ArrayRef; -use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; -use arrow::row::{RowConverter, SortField}; use async_trait::async_trait; -use datafusion::common::{DataFusionError, Result, Statistics}; -use datafusion::execution::context::TaskContext; -use datafusion::physical_expr::PhysicalSortExpr; -use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; -use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, +use datafusion::{ + common::{DataFusionError, Result, Statistics}, + execution::context::TaskContext, + physical_expr::PhysicalSortExpr, + physical_plan::{ + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, + stream::RecordBatchStreamAdapter, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, + }, }; -use datafusion_ext_commons::io::{ - read_bytes_slice, read_len, read_one_batch, write_len, write_one_batch, +use datafusion_ext_commons::{ + bytes_arena::BytesArena, + io::{read_bytes_slice, read_len, read_one_batch, write_len, write_one_batch}, + loser_tree::LoserTree, + slim_bytes::SlimBytes, + streams::coalesce_stream::CoalesceInput, }; -use datafusion_ext_commons::loser_tree::LoserTree; -use datafusion_ext_commons::streams::coalesce_stream::CoalesceStream; -use futures::lock::Mutex; -use futures::stream::once; -use futures::{StreamExt, TryStreamExt}; +use futures::{lock::Mutex, stream::once, StreamExt, TryStreamExt}; use itertools::Itertools; use lz4_flex::frame::FrameDecoder; use 
parking_lot::Mutex as SyncMutex; -use std::any::Any; -use std::collections::VecDeque; -use std::fmt::Formatter; -use std::io::{BufReader, Cursor, Read, Write}; -use std::mem::size_of; -use std::sync::{Arc, Weak}; + +use crate::{ + common::{ + batch_statisitcs::{stat_input, InputBatchStatistics}, + column_pruning::ExecuteWithColumnPruning, + output::{TaskOutputter, WrappedRecordBatchSender}, + BatchTaker, BatchesInterleaver, + }, + memmgr::{ + onheap_spill::{try_new_spill, Spill}, + MemConsumer, MemConsumerInfo, MemManager, + }, +}; const NUM_LEVELS: usize = 64; // reserve memory for each spill -// estimated size: bufread=64KB + lz4dec.src=64KB + lz4dec.dest=64KB + batches=~100KB +// estimated size: bufread=64KB + lz4dec.src=64KB + lz4dec.dest=64KB + +// batches=~100KB const SPILL_OFFHEAP_MEM_COST: usize = 300000; #[derive(Debug)] @@ -242,25 +253,19 @@ impl ExecuteWithColumnPruning for SortExec { InputBatchStatistics::from_metrics_set_and_blaze_conf(&self.metrics, partition)?, self.input.execute(partition, context.clone())?, )?; - let coalesced = Box::pin(CoalesceStream::new( + let coalesced = context.coalesce_with_default_batch_size( input, - batch_size, - BaselineMetrics::new(&self.metrics, partition) - .elapsed_compute() - .clone(), - )); + &BaselineMetrics::new(&self.metrics, partition), + )?; let output = Box::pin(RecordBatchStreamAdapter::new( self.schema(), - once(external_sort(coalesced, context, external_sorter)).try_flatten(), + once(external_sort(coalesced, context.clone(), external_sorter)).try_flatten(), )); - let coalesced = Box::pin(CoalesceStream::new( + let coalesced = context.coalesce_with_default_batch_size( output, - batch_size, - BaselineMetrics::new(&self.metrics, partition) - .elapsed_compute() - .clone(), - )); + &BaselineMetrics::new(&self.metrics, partition), + )?; Ok(coalesced) } } @@ -286,19 +291,14 @@ async fn external_sort( let has_spill = sorter.spills.lock().await.is_empty(); let sorter_cloned = sorter.clone(); - let output = output_with_sender( - "Sort", - context.clone(), - input.schema(), - |sender| async move { - sorter.output(sender).await?; - Ok(()) - }, - )?; + let output = context.output_with_sender("Sort", input.schema(), |sender| async move { + sorter.output(sender).await?; + Ok(()) + })?; // if running in-memory, buffer output when memory usage is high if !has_spill { - return output_bufferable_with_spill(sorter_cloned, context, output); + return context.output_bufferable_with_spill(sorter_cloned, output); } Ok(output) } @@ -877,20 +877,24 @@ fn max_level_id(levels: &[Option]) -> Option { #[cfg(test)] mod test { - use crate::sort_exec::SortExec; - use arrow::array::Int32Array; - use arrow::compute::SortOptions; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use datafusion::assert_batches_eq; - use datafusion::common::Result; - use datafusion::physical_expr::expressions::Column; - use datafusion::physical_expr::PhysicalSortExpr; - use datafusion::physical_plan::memory::MemoryExec; - use datafusion::physical_plan::{common, ExecutionPlan}; - use datafusion::prelude::SessionContext; use std::sync::Arc; + use arrow::{ + array::Int32Array, + compute::SortOptions, + datatypes::{DataType, Field, Schema}, + record_batch::RecordBatch, + }; + use datafusion::{ + assert_batches_eq, + common::Result, + physical_expr::{expressions::Column, PhysicalSortExpr}, + physical_plan::{common, memory::MemoryExec, ExecutionPlan}, + prelude::SessionContext, + }; + + use crate::sort_exec::SortExec; + fn build_table_i32( a: (&str, 
&Vec), b: (&str, &Vec), @@ -960,20 +964,20 @@ mod test { #[cfg(test)] mod fuzztest { - use crate::common::memory_manager::MemManager; - use crate::sort_exec::SortExec; - use arrow::compute::SortOptions; - use arrow::record_batch::RecordBatch; - use datafusion::common::{Result, ScalarValue}; - use datafusion::logical_expr::ColumnarValue; - use datafusion::physical_expr::expressions::Column; - use datafusion::physical_expr::math_expressions::random; - use datafusion::physical_expr::PhysicalSortExpr; - use datafusion::physical_plan::memory::MemoryExec; - use datafusion::prelude::{SessionConfig, SessionContext}; - use datafusion_ext_commons::concat_batches; use std::sync::Arc; + use arrow::{compute::SortOptions, record_batch::RecordBatch}; + use datafusion::{ + common::{Result, ScalarValue}, + logical_expr::ColumnarValue, + physical_expr::{expressions::Column, math_expressions::random, PhysicalSortExpr}, + physical_plan::memory::MemoryExec, + prelude::{SessionConfig, SessionContext}, + }; + use datafusion_ext_commons::concat_batches; + + use crate::{memmgr::MemManager, sort_exec::SortExec}; + #[tokio::test] async fn fuzztest() -> Result<()> { MemManager::init(10000); diff --git a/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs b/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs index 75ee8199..e41f18ad 100644 --- a/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs +++ b/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs @@ -12,38 +12,40 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::common::column_pruning::ExecuteWithColumnPruning; -use crate::common::output::{output_with_sender, WrappedRecordBatchSender}; -use crate::common::{BatchTaker, BatchesInterleaver}; -use arrow::array::*; -use arrow::buffer::NullBuffer; -use arrow::compute::{prep_null_mask_filter, SortOptions}; -use arrow::datatypes::{DataType, Schema, SchemaRef}; -use arrow::record_batch::{RecordBatch, RecordBatchOptions}; -use arrow::row::{Row, RowConverter, Rows, SortField}; -use datafusion::error::{DataFusionError, Result}; -use datafusion::execution::context::TaskContext; -use datafusion::logical_expr::JoinType; -use datafusion::logical_expr::JoinType::*; -use datafusion::physical_expr::PhysicalSortExpr; -use datafusion::physical_plan::joins::utils::{ - build_join_schema, check_join_is_valid, ColumnIndex, JoinFilter, JoinOn, JoinSide, +use std::{any::Any, cmp::Ordering, fmt::Formatter, sync::Arc}; + +use arrow::{ + array::*, + buffer::NullBuffer, + compute::{prep_null_mask_filter, SortOptions}, + datatypes::{DataType, Schema, SchemaRef}, + record_batch::{RecordBatch, RecordBatchOptions}, + row::{Row, RowConverter, Rows, SortField}, }; -use datafusion::physical_plan::metrics::{ - BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, ScopedTimerGuard, +use datafusion::{ + error::{DataFusionError, Result}, + execution::context::TaskContext, + logical_expr::{JoinType, JoinType::*}, + physical_expr::PhysicalSortExpr, + physical_plan::{ + joins::utils::{ + build_join_schema, check_join_is_valid, ColumnIndex, JoinFilter, JoinOn, JoinSide, + }, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, ScopedTimerGuard}, + stream::RecordBatchStreamAdapter, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, + Statistics, + }, }; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; -use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, 
ExecutionPlan, Partitioning, SendableRecordBatchStream, - Statistics, -}; -use datafusion_ext_commons::streams::coalesce_stream::CoalesceStream; +use datafusion_ext_commons::streams::coalesce_stream::CoalesceInput; use futures::{StreamExt, TryStreamExt}; use parking_lot::Mutex as SyncMutex; -use std::any::Any; -use std::cmp::Ordering; -use std::fmt::Formatter; -use std::sync::Arc; + +use crate::common::{ + column_pruning::ExecuteWithColumnPruning, + output::{TaskOutputter, WrappedRecordBatchSender}, + BatchTaker, BatchesInterleaver, +}; #[derive(Debug)] pub struct SortMergeJoinExec { @@ -61,7 +63,8 @@ pub struct SortMergeJoinExec { schema: SchemaRef, /// Execution metrics metrics: ExecutionPlanMetricsSet, - /// Sort options of join columns used in sorting left and right execution plans + /// Sort options of join columns used in sorting left and right execution + /// plans sort_options: Vec, } @@ -116,7 +119,8 @@ impl SortMergeJoinExec { .collect::>(); let sub_batch_size = batch_size / batch_size.ilog2() as usize; - // use smaller batch size and coalesce batches at the end, to avoid buffer overflowing + // use smaller batch size and coalesce batches at the end, to avoid buffer + // overflowing JoinParams { join_type: self.join_type, output_schema: self.schema(), @@ -350,25 +354,19 @@ fn execute_with_join_params( right: SendableRecordBatchStream, metrics: Arc, ) -> Result { - let batch_size = join_params.batch_size; let metrics_cloned = metrics.clone(); + let context_cloned = context.clone(); let output_schema = join_params.output_schema.clone(); let output_stream = Box::pin(RecordBatchStreamAdapter::new( join_params.output_schema.clone(), futures::stream::once(async move { - output_with_sender("SortMergeJoin", context, output_schema, move |sender| { + context_cloned.output_with_sender("SortMergeJoin", output_schema, move |sender| { execute_join(left, right, join_params, metrics_cloned, sender) }) }) .try_flatten(), )); - - let output_coalesced = Box::pin(CoalesceStream::new( - output_stream, - batch_size, - metrics.elapsed_compute().clone(), - )); - Ok(output_coalesced) + Ok(context.coalesce_with_default_batch_size(output_stream, &metrics)?) 
} async fn execute_join( @@ -869,24 +867,26 @@ fn compare_cursor( #[cfg(test)] mod tests { - use crate::sort_merge_join_exec::SortMergeJoinExec; - use arrow; - use arrow::array::*; - use arrow::compute::SortOptions; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use datafusion::assert_batches_sorted_eq; - use datafusion::error::Result; - use datafusion::logical_expr::JoinType; - use datafusion::logical_expr::JoinType::*; - use datafusion::physical_expr::expressions::Column; - use datafusion::physical_plan::common; - use datafusion::physical_plan::joins::utils::*; - use datafusion::physical_plan::memory::MemoryExec; - use datafusion::physical_plan::ExecutionPlan; - use datafusion::prelude::{SessionConfig, SessionContext}; use std::sync::Arc; + use arrow::{ + self, + array::*, + compute::SortOptions, + datatypes::{DataType, Field, Schema}, + record_batch::RecordBatch, + }; + use datafusion::{ + assert_batches_sorted_eq, + error::Result, + logical_expr::{JoinType, JoinType::*}, + physical_expr::expressions::Column, + physical_plan::{common, joins::utils::*, memory::MemoryExec, ExecutionPlan}, + prelude::{SessionConfig, SessionContext}, + }; + + use crate::sort_merge_join_exec::SortMergeJoinExec; + fn columns(schema: &Schema) -> Vec { schema.fields().iter().map(|f| f.name().clone()).collect() } diff --git a/native-engine/datafusion-ext-plans/src/window/mod.rs b/native-engine/datafusion-ext-plans/src/window/mod.rs index 23d8b818..4409fad6 100644 --- a/native-engine/datafusion-ext-plans/src/window/mod.rs +++ b/native-engine/datafusion-ext-plans/src/window/mod.rs @@ -12,18 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::agg::{create_agg, AggFunction}; -use crate::window::processors::agg_processor::AggProcessor; -use crate::window::processors::rank_processor::RankProcessor; -use crate::window::processors::row_number_processor::RowNumberProcessor; -use crate::window::window_context::WindowContext; -use arrow::array::ArrayRef; -use arrow::datatypes::FieldRef; -use arrow::record_batch::RecordBatch; -use datafusion::common::Result; -use datafusion::physical_expr::PhysicalExpr; use std::sync::Arc; +use arrow::{array::ArrayRef, datatypes::FieldRef, record_batch::RecordBatch}; +use datafusion::{common::Result, physical_expr::PhysicalExpr}; + +use crate::{ + agg::{create_agg, AggFunction}, + window::{ + processors::{ + agg_processor::AggProcessor, rank_processor::RankProcessor, + row_number_processor::RowNumberProcessor, + }, + window_context::WindowContext, + }, +}; + pub mod processors; pub mod window_context; diff --git a/native-engine/datafusion-ext-plans/src/window/processors/agg_processor.rs b/native-engine/datafusion-ext-plans/src/window/processors/agg_processor.rs index b90d0b5a..d41b9a8f 100644 --- a/native-engine/datafusion-ext-plans/src/window/processors/agg_processor.rs +++ b/native-engine/datafusion-ext-plans/src/window/processors/agg_processor.rs @@ -12,16 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. 
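The sort-merge join changes above keep the existing trick of producing join output in smaller sub-batches and coalescing at the end (now via coalesce_with_default_batch_size) so joffered buffers do not overflow; a quick worked example of the sub-batch sizing, mirroring the `batch_size / batch_size.ilog2()` expression shown in the hunk:

    // With the default spark.blaze.batchSize of 10000, ilog2(10000) = 13,
    // so join output is emitted in sub-batches of ~769 rows and coalesced
    // back up to full-size batches afterwards.
    fn sub_batch_size(batch_size: usize) -> usize {
        batch_size / batch_size.ilog2() as usize
    }

    fn main() {
        assert_eq!(sub_batch_size(10000), 769);
        assert_eq!(sub_batch_size(8192), 630);
    }
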
-use crate::agg::agg_buf::{create_agg_buf_from_initial_value, AggBuf}; -use crate::agg::Agg; -use crate::common::slim_bytes::SlimBytes; -use crate::window::window_context::WindowContext; -use crate::window::WindowFunctionProcessor; -use arrow::array::ArrayRef; -use arrow::record_batch::RecordBatch; -use datafusion::common::{Result, ScalarValue}; use std::sync::Arc; +use arrow::{array::ArrayRef, record_batch::RecordBatch}; +use datafusion::common::{Result, ScalarValue}; +use datafusion_ext_commons::slim_bytes::SlimBytes; + +use crate::{ + agg::{ + agg_buf::{create_agg_buf_from_initial_value, AggBuf}, + Agg, + }, + window::{window_context::WindowContext, WindowFunctionProcessor}, +}; + pub struct AggProcessor { cur_partition: SlimBytes, agg: Arc, diff --git a/native-engine/datafusion-ext-plans/src/window/processors/rank_processor.rs b/native-engine/datafusion-ext-plans/src/window/processors/rank_processor.rs index b8afd015..e476fc78 100644 --- a/native-engine/datafusion-ext-plans/src/window/processors/rank_processor.rs +++ b/native-engine/datafusion-ext-plans/src/window/processors/rank_processor.rs @@ -12,14 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::common::slim_bytes::SlimBytes; -use crate::window::window_context::WindowContext; -use crate::window::WindowFunctionProcessor; -use arrow::array::{ArrayRef, Int32Builder}; -use arrow::record_batch::RecordBatch; -use datafusion::common::Result; use std::sync::Arc; +use arrow::{ + array::{ArrayRef, Int32Builder}, + record_batch::RecordBatch, +}; +use datafusion::common::Result; +use datafusion_ext_commons::slim_bytes::SlimBytes; + +use crate::window::{window_context::WindowContext, WindowFunctionProcessor}; + pub struct RankProcessor { cur_partition: SlimBytes, cur_order: SlimBytes, diff --git a/native-engine/datafusion-ext-plans/src/window/processors/row_number_processor.rs b/native-engine/datafusion-ext-plans/src/window/processors/row_number_processor.rs index 734e9691..92e0ec46 100644 --- a/native-engine/datafusion-ext-plans/src/window/processors/row_number_processor.rs +++ b/native-engine/datafusion-ext-plans/src/window/processors/row_number_processor.rs @@ -12,13 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::window::window_context::WindowContext; -use crate::window::WindowFunctionProcessor; -use arrow::array::{ArrayRef, Int32Builder}; -use arrow::record_batch::RecordBatch; -use datafusion::common::Result; use std::sync::Arc; +use arrow::{ + array::{ArrayRef, Int32Builder}, + record_batch::RecordBatch, +}; +use datafusion::common::Result; + +use crate::window::{window_context::WindowContext, WindowFunctionProcessor}; + pub struct RowNumberProcessor { cur_partition: Box<[u8]>, cur_row_number: i32, diff --git a/native-engine/datafusion-ext-plans/src/window/window_context.rs b/native-engine/datafusion-ext-plans/src/window/window_context.rs index 260bf1c0..bc6f656b 100644 --- a/native-engine/datafusion-ext-plans/src/window/window_context.rs +++ b/native-engine/datafusion-ext-plans/src/window/window_context.rs @@ -12,14 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. 
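The RankProcessor touched above keeps the current partition key and current order key (as SlimBytes) and emits Int32 ranks. As a simplified model of the semantics, ignoring the arrow row encodings the real processor operates on: over rows already sorted by (partition key, order key), rank restarts at 1 on a new partition and only advances when the order key changes, so ties share a rank:

    // Simplified rank() over pre-sorted (partition, order) byte keys.
    fn rank_over_sorted(keys: &[(&[u8], &[u8])]) -> Vec<i32> {
        let mut out = Vec::with_capacity(keys.len());
        let mut cur_partition: Option<&[u8]> = None;
        let mut cur_order: Option<&[u8]> = None;
        let (mut rank, mut row_in_partition) = (0i32, 0i32);

        for &(partition, order) in keys {
            if cur_partition != Some(partition) {
                // new partition: reset both keys and start counting again
                cur_partition = Some(partition);
                cur_order = Some(order);
                rank = 1;
                row_in_partition = 1;
            } else {
                row_in_partition += 1;
                if cur_order != Some(order) {
                    // order key changed: rank jumps to the current row number
                    cur_order = Some(order);
                    rank = row_in_partition;
                }
            }
            out.push(rank);
        }
        out
    }
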
-use crate::window::WindowExpr; -use arrow::datatypes::{Field, FieldRef, Fields, Schema, SchemaRef}; -use arrow::record_batch::RecordBatch; -use arrow::row::{RowConverter, Rows, SortField}; -use datafusion::common::Result; -use datafusion::physical_expr::{PhysicalExpr, PhysicalSortExpr}; use std::sync::{Arc, Mutex as SyncMutex}; +use arrow::{ + datatypes::{Field, FieldRef, Fields, Schema, SchemaRef}, + record_batch::RecordBatch, + row::{RowConverter, Rows, SortField}, +}; +use datafusion::{ + common::Result, + physical_expr::{PhysicalExpr, PhysicalSortExpr}, +}; + +use crate::window::WindowExpr; + #[derive(Debug)] pub struct WindowContext { pub window_exprs: Vec, diff --git a/native-engine/datafusion-ext-plans/src/window_exec.rs b/native-engine/datafusion-ext-plans/src/window_exec.rs index c2223bfa..b11a6ee5 100644 --- a/native-engine/datafusion-ext-plans/src/window_exec.rs +++ b/native-engine/datafusion-ext-plans/src/window_exec.rs @@ -12,28 +12,32 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::common::output::output_with_sender; -use crate::window::window_context::WindowContext; -use crate::window::{WindowExpr, WindowFunctionProcessor}; -use arrow::array::ArrayRef; -use arrow::datatypes::SchemaRef; -use arrow::error::ArrowError; -use arrow::record_batch::{RecordBatch, RecordBatchOptions}; -use datafusion::common::{Result, Statistics}; -use datafusion::execution::context::TaskContext; -use datafusion::physical_expr::PhysicalSortExpr; -use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; -use datafusion::physical_plan::stream::RecordBatchStreamAdapter; -use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr, - SendableRecordBatchStream, +use std::{any::Any, fmt::Formatter, sync::Arc}; + +use arrow::{ + array::ArrayRef, + datatypes::SchemaRef, + error::ArrowError, + record_batch::{RecordBatch, RecordBatchOptions}, +}; +use datafusion::{ + common::{Result, Statistics}, + execution::context::TaskContext, + physical_expr::PhysicalSortExpr, + physical_plan::{ + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, + stream::RecordBatchStreamAdapter, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr, + SendableRecordBatchStream, + }, +}; +use datafusion_ext_commons::streams::coalesce_stream::CoalesceInput; +use futures::{stream::once, StreamExt, TryFutureExt, TryStreamExt}; + +use crate::{ + common::output::TaskOutputter, + window::{window_context::WindowContext, WindowExpr, WindowFunctionProcessor}, }; -use datafusion_ext_commons::streams::coalesce_stream::CoalesceStream; -use futures::stream::once; -use futures::{StreamExt, TryFutureExt, TryStreamExt}; -use std::any::Any; -use std::fmt::Formatter; -use std::sync::Arc; #[derive(Debug)] pub struct WindowExec { @@ -109,13 +113,10 @@ impl ExecutionPlan for WindowExec { ) -> Result { // at this moment only supports ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW let input = self.input.execute(partition, context.clone())?; - let coalesced = Box::pin(CoalesceStream::new( + let coalesced = context.coalesce_with_default_batch_size( input, - context.session_config().batch_size(), - BaselineMetrics::new(&self.metrics, partition) - .elapsed_compute() - .clone(), - )); + &BaselineMetrics::new(&self.metrics, partition), + )?; let stream = execute_window( coalesced, @@ -153,57 +154,55 @@ async fn execute_window( .collect::>()?; // start processing input 
batches - output_with_sender( - "Window", - task_context, - context.output_schema.clone(), - |sender| async move { - while let Some(batch) = input.next().await.transpose()? { - let elapsed_time = metrics.elapsed_compute().clone(); - let mut timer = elapsed_time.timer(); + let output_schema = context.output_schema.clone(); + task_context.output_with_sender("Window", output_schema, |sender| async move { + while let Some(batch) = input.next().await.transpose()? { + let elapsed_time = metrics.elapsed_compute().clone(); + let mut timer = elapsed_time.timer(); - let window_cols: Vec = processors - .iter_mut() - .map(|processor| { - if context.partition_spec.is_empty() { - processor.process_batch_without_partitions(&context, &batch) - } else { - processor.process_batch(&context, &batch) - } - }) - .collect::>()?; + let window_cols: Vec = processors + .iter_mut() + .map(|processor| { + if context.partition_spec.is_empty() { + processor.process_batch_without_partitions(&context, &batch) + } else { + processor.process_batch(&context, &batch) + } + }) + .collect::>()?; - let output_cols = [batch.columns().to_vec(), window_cols].concat(); - let output_batch = RecordBatch::try_new_with_options( - context.output_schema.clone(), - output_cols, - &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())), - )?; + let output_cols = [batch.columns().to_vec(), window_cols].concat(); + let output_batch = RecordBatch::try_new_with_options( + context.output_schema.clone(), + output_cols, + &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())), + )?; - metrics.record_output(output_batch.num_rows()); - sender.send(Ok(output_batch), Some(&mut timer)).await; - } - Ok(()) - }, - ) + metrics.record_output(output_batch.num_rows()); + sender.send(Ok(output_batch), Some(&mut timer)).await; + } + Ok(()) + }) } #[cfg(test)] mod test { - use crate::agg::AggFunction; - use crate::window::{WindowExpr, WindowFunction, WindowRankType}; - use crate::window_exec::WindowExec; - use arrow::array::*; - use arrow::datatypes::*; - use arrow::record_batch::RecordBatch; - use datafusion::assert_batches_eq; - use datafusion::physical_expr::expressions::Column; - use datafusion::physical_expr::PhysicalSortExpr; - use datafusion::physical_plan::memory::MemoryExec; - use datafusion::physical_plan::ExecutionPlan; - use datafusion::prelude::SessionContext; use std::sync::Arc; + use arrow::{array::*, datatypes::*, record_batch::RecordBatch}; + use datafusion::{ + assert_batches_eq, + physical_expr::{expressions::Column, PhysicalSortExpr}, + physical_plan::{memory::MemoryExec, ExecutionPlan}, + prelude::SessionContext, + }; + + use crate::{ + agg::AggFunction, + window::{WindowExpr, WindowFunction, WindowRankType}, + window_exec::WindowExec, + }; + fn build_table_i32( a: (&str, &Vec), b: (&str, &Vec), diff --git a/pom.xml b/pom.xml index c5a420ef..ca41a708 100644 --- a/pom.xml +++ b/pom.xml @@ -15,7 +15,7 @@ 2.0.7-SNAPSHOT UTF-8 - 13.0.0 + 15.0.0-SNAPSHOT 3.21.9 @@ -74,6 +74,17 @@ + + + + true + + arrow-snapshot + https://nightlies.apache.org/arrow/java/ + default + + + diff --git a/rustfmt.toml b/rustfmt.toml index bf73280e..c509909c 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,4 +1,16 @@ edition = "2021" +unstable_features = true + max_width = 100 -comment_width = 80 -array_width = 80 +wrap_comments = true +format_code_in_doc_comments = true +format_macro_bodies = true +format_macro_matchers = true +normalize_comments = true +normalize_doc_attributes = true +condense_wildcard_suffixes = true +newline_style = "Unix" 
+use_field_init_shorthand = true +use_try_shorthand = true +imports_granularity = "Crate" +group_imports = "StdExternalCrate" diff --git a/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java b/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java index 3b295ba5..3cf35eff 100644 --- a/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java +++ b/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java @@ -18,69 +18,92 @@ import org.apache.spark.SparkConf; import org.apache.spark.SparkEnv$; -public class BlazeConf { +@SuppressWarnings("unused") +public enum BlazeConf { /// suggested batch size for arrow batches. - public static int batchSize() { - return intConf("spark.blaze.batchSize", 10000); - } + BATCH_SIZE("spark.blaze.batchSize", 10000), /// suggested fraction of off-heap memory used in native execution. /// actual off-heap memory usage is expected to be spark.executor.memoryOverhead * fraction. - public static double memoryFraction() { - return doubleConf("spark.blaze.memoryFraction", 0.6); - } + MEMORY_FRACTION("spark.blaze.memoryFraction", 0.6), /// translates inequality smj to native. improves performance in most cases, however some /// issues are found in special cases, like tpcds q72. - public static boolean enableSmjInequalityJoin() { - return booleanConf("spark.blaze.enable.smjInequalityJoin", false); - } + SMJ_INEQUALITY_JOIN_ENABLE("spark.blaze.enable.smjInequalityJoin", false), /// fallbacks to SortMergeJoin when executing BroadcastHashJoin with big broadcasted table. - public static boolean enableBhjFallbacksToSmj() { - return booleanConf("spark.blaze.enable.bhjFallbacksToSmj", true); - } + BHJ_FALLBACKS_TO_SMJ_ENABLE("spark.blaze.enable.bhjFallbacksToSmj", true), /// fallbacks to SortMergeJoin when BroadcastHashJoin has a broadcasted table with rows more /// than this threshold. requires spark.blaze.enable.bhjFallbacksToSmj = true. - public static int bhjFallbacksToSmjRowsThreshold() { - return intConf("spark.blaze.bhjFallbacksToSmj.rows", 1000000); - } + BHJ_FALLBACKS_TO_SMJ_ROWS_THRESHOLD("spark.blaze.bhjFallbacksToSmj.rows", 1000000), /// fallbacks to SortMergeJoin when BroadcastHashJoin has a broadcasted table with memory usage /// more than this threshold. requires spark.blaze.enable.bhjFallbacksToSmj = true. - public static int bhjFallbacksToSmjMemThreshold() { - return intConf("spark.blaze.bhjFallbacksToSmj.mem.bytes", 134217728); - } + BHJ_FALLBACKS_TO_SMJ_MEM_THRESHOLD("spark.blaze.bhjFallbacksToSmj.mem.bytes", 134217728), /// enable converting upper/lower functions to native, special cases may provide different /// outputs from spark due to different unicode versions. 
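+    // A minimal call-site sketch (illustrative only, assuming the accessors defined later in this
+    // hunk): each enum constant above pairs a Spark conf key with its default, and callers read it
+    // through the typed accessor matching that default, as the Scala call sites updated later in
+    // this patch do:
+    //
+    //   int batchSize = BlazeConf.BATCH_SIZE.intConf();              // "spark.blaze.batchSize", 10000
+    //   double memFraction = BlazeConf.MEMORY_FRACTION.doubleConf(); // "spark.blaze.memoryFraction", 0.6
+    //   boolean bhjToSmj = BlazeConf.BHJ_FALLBACKS_TO_SMJ_ENABLE.booleanConf();
+    //
+    // The static String-keyed helpers (e.g. BlazeConf.intConf("BATCH_SIZE")) resolve the constant
+    // via Enum.valueOf, so they take the enum constant's name rather than the Spark conf key.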
- public static boolean enableCaseConvertFunctions() { - return booleanConf("spark.blaze.enable.caseconvert.functions", false); + CASE_CONVERT_FUNCTIONS_ENABLE("spark.blaze.enable.caseconvert.functions", false), + + /// number of threads evaluating UDFs + /// improves performance in special cases where UDF concurrency matters + UDF_WRAPPER_NUM_THREADS("spark.blaze.udfWrapperNumThreads", 1), + + /// enable extra metrics of input batch statistics + INPUT_BATCH_STATISTICS_ENABLE("spark.blaze.enableInputBatchStatistics", false), + + /// ignore corrupted input files + IGNORE_CORRUPTED_FILES("spark.files.ignoreCorruptFiles", false), + + /// enable partial aggregate skipping (see https://github.com/blaze-init/blaze/issues/327) + PARTIAL_AGG_SKIPPING_ENABLE("spark.blaze.partialAggSkipping.enable", true), + + /// partial aggregate skipping ratio + PARTIAL_AGG_SKIPPING_RATIO("spark.blaze.partialAggSkipping.ratio", 0.8), + + /// minimum number of rows to trigger partial aggregate skipping + PARTIAL_AGG_SKIPPING_MIN_ROWS("spark.blaze.partialAggSkipping.minRows", BATCH_SIZE.intConf() * 2), + ; + + private String key; + private Object defaultValue; + + BlazeConf(String key, Object defaultValue) { + this.key = key; + this.defaultValue = defaultValue; + } + + public boolean booleanConf() { + return conf().getBoolean(key, (boolean) defaultValue); + } + + public int intConf() { + return conf().getInt(key, (int) defaultValue); } - public static int udfWrapperNumThreads() { - return intConf("spark.blaze.udfWrapperNumThreads", 1); + public long longConf() { + return conf().getLong(key, (long) defaultValue); } - public static boolean enableInputBatchStatistics() { - return booleanConf("spark.blaze.enableInputBatchStatistics", false); + public double doubleConf() { + return conf().getDouble(key, (double) defaultValue); } - public static boolean ignoreCorruptedFiles() { - return booleanConf("spark.files.ignoreCorruptFiles", false); + public static boolean booleanConf(String confName) { + return BlazeConf.valueOf(confName).booleanConf(); } - private static int intConf(String key, int defaultValue) { - return conf().getInt(key, defaultValue); + public static int intConf(String confName) { + return BlazeConf.valueOf(confName).intConf(); } - private static double doubleConf(String key, double defaultValue) { - return conf().getDouble(key, defaultValue); + public static long longConf(String confName) { + return BlazeConf.valueOf(confName).longConf(); } - private static boolean booleanConf(String key, boolean defaultValue) { - return conf().getBoolean(key, defaultValue); + public static double doubleConf(String confName) { + return BlazeConf.valueOf(confName).doubleConf(); } private static SparkConf conf() { diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala index 83498782..ca75da28 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala @@ -20,19 +20,27 @@ import java.io.IOException import java.nio.file.Files import java.nio.file.StandardCopyOption import java.util.concurrent.atomic.AtomicReference - -import org.apache.spark.InterruptibleIterator +import java.util.concurrent.ArrayBlockingQueue + +import org.apache.arrow.c.ArrowArray +import org.apache.arrow.c.ArrowSchema +import org.apache.arrow.c.CDataDictionaryProvider +import 
org.apache.arrow.c.Data +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.types.pojo.Schema import org.apache.spark.Partition import org.apache.spark.TaskContext -import org.blaze.protobuf.PartitionId -import org.blaze.protobuf.PhysicalPlanNode -import org.blaze.protobuf.TaskDefinition - import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.blaze.arrowio.ArrowFFIStreamImportIterator +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowUtils +import org.apache.spark.sql.execution.blaze.arrowio.ColumnarHelper import org.apache.spark.util.CompletionIterator import org.apache.spark.util.Utils +import org.blaze.protobuf.PartitionId +import org.blaze.protobuf.PhysicalPlanNode +import org.blaze.protobuf.TaskDefinition case class BlazeCallNativeWrapper( nativePlan: PhysicalPlanNode, @@ -44,15 +52,33 @@ case class BlazeCallNativeWrapper( BlazeCallNativeWrapper.initNative() private val error: AtomicReference[Throwable] = new AtomicReference(null) - private var arrowFFIStreamPtr = 0L + private val dictionaryProvider = new CDataDictionaryProvider() + private val recordsQueue = new ArrayBlockingQueue[Option[UnsafeRow]](256) + private var arrowSchema: Schema = _ logInfo(s"Start executing native plan") private var nativeRuntimePtr = JniBridge.callNative(this) - private var rowIterator = { - val iter = new ArrowFFIStreamImportIterator(context, arrowFFIStreamPtr, checkError) - context match { - case Some(tc) => new InterruptibleIterator[InternalRow](tc, iter) - case None => iter + + private lazy val rowIterator = new Iterator[InternalRow] { + private var currentRecord: InternalRow = _ + + override def hasNext: Boolean = { + if (currentRecord != null) { + return true + } + recordsQueue.take() match { + case Some(row) => + currentRecord = row + true + case None => + false + } + } + + override def next(): InternalRow = { + val nextRecord = currentRecord + currentRecord = null + nextRecord } } @@ -66,21 +92,54 @@ case class BlazeCallNativeWrapper( protected def getMetrics: MetricNode = metrics + protected def importSchema(ffiSchemaPtr: Long): Unit = { + val ffiSchema = ArrowSchema.wrap(ffiSchemaPtr) + arrowSchema = Data.importSchema(ArrowUtils.rootAllocator, ffiSchema, dictionaryProvider) + } + + protected def importBatch(ffiArrayPtr: Long): Unit = { + checkError() + + if (ffiArrayPtr == 0) { + recordsQueue.put(None) // finished + return + } + val root: VectorSchemaRoot = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) + val ffiArray = ArrowArray.wrap(ffiArrayPtr) + Utils.tryWithSafeFinally { + Data.importIntoVectorSchemaRoot( + ArrowUtils.rootAllocator, + ffiArray, + root, + dictionaryProvider) + + val toUnsafe = UnsafeProjection.create(ArrowUtils.fromArrowSchema(root.getSchema)) + toUnsafe.initialize(Option(TaskContext.get()).map(_.partitionId()).getOrElse(0)) + + val batch = ColumnarHelper.rootAsBatch(root) + for (row <- ColumnarHelper.batchAsRowIter(batch)) { + checkError() + recordsQueue.put(Some(toUnsafe(row).copy())) + } + } { + root.close() + ffiArray.close() + } + } + protected def setError(error: Throwable): Unit = { this.error.set(error) + terminateRecordsQueue() } protected def checkError(): Unit = { val throwable = error.getAndSet(null) if (throwable != null) { + terminateRecordsQueue() throw throwable } } - protected def 
setArrowFFIStreamPtr(ptr: Long): Unit = { - this.arrowFFIStreamPtr = ptr - } - protected def getRawTaskDefinition: Array[Byte] = { val partitionId: PartitionId = PartitionId .newBuilder() @@ -97,21 +156,18 @@ case class BlazeCallNativeWrapper( taskDefinition.toByteArray } + private def terminateRecordsQueue(): Unit = { + recordsQueue.clear() + recordsQueue.put(None) + } + private def close(): Unit = { synchronized { - if (rowIterator != null) { - rowIterator match { - case iter: InterruptibleIterator[_] => - iter.delegate.asInstanceOf[ArrowFFIStreamImportIterator].close() - case iter: ArrowFFIStreamImportIterator => - iter.close() - } - rowIterator = null - } if (nativeRuntimePtr != 0) { JniBridge.finalizeNative(nativeRuntimePtr) nativeRuntimePtr = 0 } + terminateRecordsQueue() checkError() } } @@ -125,9 +181,9 @@ object BlazeCallNativeWrapper extends Logging { private lazy val lazyInitNative: Unit = { logInfo( "Initializing native environment (" + - s"batchSize=${BlazeConf.batchSize}, " + + s"batchSize=${BlazeConf.BATCH_SIZE.intConf()}, " + s"nativeMemory=${NativeHelper.nativeMemory}, " + - s"memoryFraction=${BlazeConf.memoryFraction}") + s"memoryFraction=${BlazeConf.MEMORY_FRACTION.doubleConf()}") BlazeCallNativeWrapper.loadLibBlaze() JniBridge.initNative(NativeHelper.nativeMemory) diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala index 09556d10..4de65791 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala @@ -780,9 +780,9 @@ object NativeConverters extends Logging { case Length(arg) if arg.dataType == StringType => buildScalarFunction(pb.ScalarFunction.CharacterLength, arg :: Nil, IntegerType) - case e: Lower if BlazeConf.enableCaseConvertFunctions() => + case e: Lower if BlazeConf.CASE_CONVERT_FUNCTIONS_ENABLE.booleanConf() => buildExtScalarFunction("StringLower", e.children, e.dataType) - case e: Upper if BlazeConf.enableCaseConvertFunctions() => + case e: Upper if BlazeConf.CASE_CONVERT_FUNCTIONS_ENABLE.booleanConf() => buildExtScalarFunction("StringUpper", e.children, e.dataType) case e: StringTrim => @@ -864,13 +864,11 @@ object NativeConverters extends Logging { case e: Concat if e.children.forall(_.dataType == StringType) => buildExtScalarFunction("StringConcat", e.children, e.dataType) - case e: ConcatWs if e.children.nonEmpty => - assert( - e.children.head.isInstanceOf[Literal], - "only supports concat_ws with literal seperator") - assert( - e.children.forall(c => c.dataType == StringType || c.dataType == ArrayType(StringType)), - "only supports concat_ws with string or array type") + case e: ConcatWs + if e.children.nonEmpty + && e.children.head.isInstanceOf[Literal] + && e.children.forall(c => + c.dataType == StringType || c.dataType == ArrayType(StringType)) => buildExtScalarFunction("StringConcatWs", e.children, e.dataType) case e: Coalesce => buildScalarFunction(pb.ScalarFunction.Coalesce, e.children, e.dataType) diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeHelper.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeHelper.scala index 97834a84..2bc35183 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeHelper.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeHelper.scala @@ -86,7 +86,7 @@ object NativeHelper extends 
Logging { "join_time" -> SQLMetrics.createNanoTimingMetric(sc, "Native.join_time"), "spilled_bytes" -> SQLMetrics.createSizeMetric(sc, "Native.spilled_bytes")) - if (BlazeConf.enableInputBatchStatistics()) { + if (BlazeConf.INPUT_BATCH_STATISTICS_ENABLE.booleanConf()) { metrics ++= TreeMap( "input_batch_count" -> SQLMetrics.createMetric(sc, "Native.input_batches"), "input_row_count" -> SQLMetrics.createMetric(sc, "Native.input_rows"), diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/ArrowFFIExportIterator.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/ArrowFFIExportIterator.scala index 601daebe..33700e0a 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/ArrowFFIExportIterator.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/ArrowFFIExportIterator.scala @@ -32,7 +32,7 @@ class ArrowFFIExportIterator( rowIter: Iterator[InternalRow], schema: StructType, taskContext: TaskContext, - recordBatchSize: Int = BlazeConf.batchSize) + recordBatchSize: Int = BlazeConf.BATCH_SIZE.intConf()) extends Iterator[(Long, Long) => Unit] with Logging { diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/ArrowFFIStreamImportIterator.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/ArrowFFIStreamImportIterator.scala index d7a34486..7c8e7640 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/ArrowFFIStreamImportIterator.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/ArrowFFIStreamImportIterator.scala @@ -75,7 +75,7 @@ class ArrowFFIStreamImportIterator( } finally { // current batch can be closed after all rows converted to UnsafeRow currentBatch.close() - reader.getVectorSchemaRoot.clear() + // reader.getVectorSchemaRoot.clear() } hasNext } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeBase.scala index 85ffc81c..a14e6392 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeBase.scala @@ -82,7 +82,7 @@ abstract class ConvertToNativeBase(override val child: SparkPlan) inputRowIter, renamedSchema, context, - recordBatchSize = BlazeConf.batchSize / 4) + recordBatchSize = BlazeConf.BATCH_SIZE.intConf() / 4) new InterruptibleIterator(context, exportIter) }) diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeAggBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeAggBase.scala index fabf68d3..0bfd93f6 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeAggBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeAggBase.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import org.apache.spark.OneToOneDependency import org.apache.spark.internal.Logging +import org.apache.spark.sql.blaze.BlazeConf import org.apache.spark.sql.blaze.MetricNode import org.apache.spark.sql.blaze.NativeConverters import org.apache.spark.sql.blaze.NativeHelper @@ -135,6 +136,18 @@ abstract class NativeAggBase( override def outputPartitioning: Partitioning = child.outputPartitioning 
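+  // Sketch of how the partial-aggregate skipping introduced below can be tuned (illustrative
+  // only; the values are hypothetical, while the keys and defaults come from the BlazeConf
+  // entries added earlier in this patch):
+  //
+  //   val conf = new org.apache.spark.SparkConf()
+  //     .set("spark.blaze.partialAggSkipping.enable", "false")  // default true; disables skipping
+  //     .set("spark.blaze.partialAggSkipping.ratio", "0.9")     // default 0.8
+  //     .set("spark.blaze.partialAggSkipping.minRows", "50000") // default batchSize * 2
+  //
+  // As encoded in supportsPartialSkipping below, skipping is only attempted when all aggregate
+  // expressions are in Partial mode, initialInputBufferOffset is 0, no specific child
+  // distribution is required, and the child is not an Expand (or a Project directly over one).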
+ private val supportsPartialSkipping = ( + BlazeConf.PARTIAL_AGG_SKIPPING_ENABLE.booleanConf() + && (child match { // do not trigger skipping after ExpandExec + case _: NativeExpandBase => false + case c: NativeProjectBase if c.child.isInstanceOf[NativeExpandBase] => false + case _ => true + }) + && initialInputBufferOffset == 0 + && aggregateExpressions.forall(_.mode == Partial) + && requiredChildDistribution.forall(_ == UnspecifiedDistribution) + ) + override def doExecuteNative(): NativeRDD = { val inputRDD = NativeHelper.executeNative(child) val nativeMetrics = MetricNode(metrics, inputRDD.metrics :: Nil) @@ -155,7 +168,6 @@ abstract class NativeAggBase( lazy val inputPlan = inputRDD.nativePlan(inputRDD.partitions(partition.index), taskContext) - pb.PhysicalPlanNode .newBuilder() .setAgg( @@ -168,6 +180,7 @@ abstract class NativeAggBase( .addAllAggExpr(nativeAggrs.asJava) .addAllGroupingExpr(nativeGroupingExprs.asJava) .setInitialInputBufferOffset(initialInputBufferOffset) + .setSupportsPartialSkipping(supportsPartialSkipping) .setInput(inputPlan)) .build() }, diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeBase.scala index bec651fa..7ab0347d 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeBase.scala @@ -265,7 +265,7 @@ object NativeBroadcastExchangeBase { keys: Seq[Expression], nativeSchema: pb.Schema): Array[Array[Byte]] = { - if (!BlazeConf.enableBhjFallbacksToSmj() || keys.isEmpty) { + if (!BlazeConf.BHJ_FALLBACKS_TO_SMJ_ENABLE.booleanConf() || keys.isEmpty) { return collectedData // no need to sort data in driver side } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala index cefe5a4c..642bc60e 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala @@ -51,8 +51,8 @@ abstract class NativeBroadcastJoinBase( "Semi/Anti join with filter is not supported yet") assert( - !BlazeConf.enableBhjFallbacksToSmj() || BlazeConf - .enableSmjInequalityJoin() || condition.isEmpty, + !BlazeConf.BHJ_FALLBACKS_TO_SMJ_ENABLE.booleanConf() || BlazeConf.SMJ_INEQUALITY_JOIN_ENABLE + .booleanConf() || condition.isEmpty, "Join filter is not supported when BhjFallbacksToSmj and SmjInequalityJoin both enabled") override lazy val metrics: Map[String, SQLMetric] = Map( diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortMergeJoinBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortMergeJoinBase.scala index a0602d16..1df7b96d 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortMergeJoinBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortMergeJoinBase.scala @@ -54,7 +54,7 @@ abstract class NativeSortMergeJoinBase( "Semi/Anti join with filter is not supported yet") assert( - BlazeConf.enableSmjInequalityJoin() || condition.isEmpty, + 
BlazeConf.SMJ_INEQUALITY_JOIN_ENABLE.booleanConf() || condition.isEmpty, "inequality sort-merge join is not enabled") override lazy val metrics: Map[String, SQLMetric] = Map(