From 4e8ac98fbbebbf965eebba5cc40ecf7c590a6d28 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 18 Mar 2024 13:37:17 -0400 Subject: [PATCH 01/35] chore(deps-dev): bump follow-redirects (#9609) Bumps [follow-redirects](https://github.com/follow-redirects/follow-redirects) from 1.15.4 to 1.15.6. - [Release notes](https://github.com/follow-redirects/follow-redirects/releases) - [Commits](https://github.com/follow-redirects/follow-redirects/compare/v1.15.4...v1.15.6) --- updated-dependencies: - dependency-name: follow-redirects dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .../wasmtest/datafusion-wasm-app/package-lock.json | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json index 5163c99bd5ac..aac87845bc9f 100644 --- a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json +++ b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json @@ -1666,9 +1666,9 @@ } }, "node_modules/follow-redirects": { - "version": "1.15.4", - "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz", - "integrity": "sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==", + "version": "1.15.6", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", + "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==", "dev": true, "funding": [ { @@ -5580,9 +5580,9 @@ } }, "follow-redirects": { - "version": "1.15.4", - "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz", - "integrity": "sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==", + "version": "1.15.6", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", + "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==", "dev": true }, "forwarded": { From 449738cd41158cb7cf65ad98abb8fda882256586 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Tue, 19 Mar 2024 02:21:07 +0800 Subject: [PATCH 02/35] move array_replace family functions to datafusion-function-array crate (#9651) * Add array replace functions * fix ci * fix ci * Update dependencies in Cargo.lock file * Fix formatting in comment * fix ci * rename mod * fix conflict * remove duplicated function * fix: clippy --------- Co-authored-by: Andrew Lamb --- datafusion-cli/Cargo.lock | 47 +- datafusion/core/benches/array_expression.rs | 42 +- datafusion/expr/src/built_in_function.rs | 28 -- datafusion/expr/src/expr_fn.rs | 24 - datafusion/functions-array/Cargo.toml | 1 + datafusion/functions-array/src/core.rs | 2 +- datafusion/functions-array/src/lib.rs | 9 +- datafusion/functions-array/src/position.rs | 106 +---- datafusion/functions-array/src/replace.rs | 362 +++++++++++++++ datafusion/functions-array/src/utils.rs | 6 +- .../physical-expr/src/array_expressions.rs | 423 ------------------ datafusion/physical-expr/src/functions.rs | 16 +- datafusion/physical-expr/src/lib.rs | 1 - datafusion/proto/proto/datafusion.proto | 6 +- datafusion/proto/src/generated/pbjson.rs | 9 - datafusion/proto/src/generated/prost.rs | 12 +- .../proto/src/logical_plan/from_proto.rs | 25 +- datafusion/proto/src/logical_plan/to_proto.rs | 3 - .../tests/cases/roundtrip_logical_plan.rs | 8 + 19 files changed, 435 insertions(+), 695 deletions(-) create mode 100644 datafusion/functions-array/src/replace.rs delete mode 100644 datafusion/physical-expr/src/array_expressions.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index deda497d9dd3..8e2a2c353e2d 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -378,13 +378,13 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.77" +version = "0.1.78" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" +checksum = "461abc97219de0eaaf81fe3ef974a540158f3d079c2ab200f891f1a2ef201e85" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -776,9 +776,9 @@ dependencies = [ [[package]] name = "brotli" -version = "3.4.0" +version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "516074a47ef4bce09577a3b379392300159ce5b1ba2e501ff1c819950066100f" +checksum = "d640d25bc63c50fb1f0b545ffd80207d2e10a4c965530809b40ba3386825c391" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -1080,7 +1080,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ad291aa74992b9b7a7e88c38acbbf6ad7e107f1d90ee8775b7bc1fc3394f485c" dependencies = [ "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -1270,6 +1270,7 @@ dependencies = [ "arrow", "arrow-array", "arrow-buffer", + "arrow-ord", "arrow-schema", "datafusion-common", "datafusion-execution", @@ -1639,7 +1640,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -1713,9 +1714,9 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "h2" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb2c4422095b67ee78da96fbb51a4cc413b3b25883c7717ff7ca1ab31022c9c9" +checksum = "4fbd2820c5e49886948654ab546d0688ff24530286bdcf8fca3cefb16d4618eb" dependencies = [ "bytes", "fnv", @@ -2560,7 +2561,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -3101,7 +3102,7 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -3236,7 +3237,7 @@ checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -3282,7 +3283,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -3295,7 +3296,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -3317,9 +3318,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.52" +version = "2.0.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b699d15b36d1f02c3e7c69f8ffef53de37aefae075d8488d4ba1a7788d574a07" +checksum = "7383cd0e49fff4b6b90ca5670bfd3e9d6a733b3f90c686605aa7eec8c4996032" dependencies = [ "proc-macro2", "quote", @@ -3403,7 +3404,7 @@ checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -3498,7 +3499,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -3595,7 +3596,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -3640,7 +3641,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -3794,7 +3795,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", "wasm-bindgen-shared", ] @@ -3828,7 +3829,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -4086,7 +4087,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] diff --git a/datafusion/core/benches/array_expression.rs b/datafusion/core/benches/array_expression.rs index 95bc93e0e353..c980329620aa 100644 --- a/datafusion/core/benches/array_expression.rs +++ b/datafusion/core/benches/array_expression.rs @@ -22,48 +22,32 @@ extern crate datafusion; mod data_utils; use crate::criterion::Criterion; -use arrow_array::cast::AsArray; -use arrow_array::types::Int64Type; -use arrow_array::{ArrayRef, Int64Array, ListArray}; -use datafusion_physical_expr::array_expressions; -use std::sync::Arc; +use datafusion::functions_array::expr_fn::{array_replace_all, make_array}; +use datafusion_expr::lit; fn criterion_benchmark(c: &mut Criterion) { // Construct large arrays for benchmarking let array_len = 100000000; - let array = (0..array_len).map(|_| Some(2_i64)).collect::>(); - let list_array = ListArray::from_iter_primitive::(vec![ - Some(array.clone()), - Some(array.clone()), - Some(array), - ]); - let from_array = Int64Array::from_value(2, 3); - let to_array = Int64Array::from_value(-2, 3); + let array = (0..array_len).map(|_| lit(2_i64)).collect::>(); + let list_array = make_array(vec![make_array(array); 3]); + let from_array = make_array(vec![lit(2_i64); 3]); + let to_array = make_array(vec![lit(-2_i64); 3]); - let args = vec![ - Arc::new(list_array) as ArrayRef, - Arc::new(from_array) as ArrayRef, - Arc::new(to_array) as ArrayRef, - ]; - - let array = (0..array_len).map(|_| Some(-2_i64)).collect::>(); - let expected_array = ListArray::from_iter_primitive::(vec![ - Some(array.clone()), - Some(array.clone()), - Some(array), - ]); + let expected_array = list_array.clone(); // Benchmark array functions c.bench_function("array_replace", |b| { b.iter(|| { assert_eq!( - array_expressions::array_replace_all(args.as_slice()) - .unwrap() - .as_list::(), - criterion::black_box(&expected_array) + array_replace_all( + list_array.clone(), + from_array.clone(), + to_array.clone() + ), + *criterion::black_box(&expected_array) ) }) }); diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index fe3397b1af52..79cd6a24ce39 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -102,14 +102,6 @@ pub enum BuiltinScalarFunction { /// cot Cot, - // array functions - /// array_replace - ArrayReplace, - /// array_replace_n - ArrayReplaceN, - /// array_replace_all - ArrayReplaceAll, - // string functions /// ascii Ascii, @@ -262,9 +254,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Cbrt => Volatility::Immutable, BuiltinScalarFunction::Cot => Volatility::Immutable, BuiltinScalarFunction::Trunc => Volatility::Immutable, - BuiltinScalarFunction::ArrayReplace => Volatility::Immutable, - BuiltinScalarFunction::ArrayReplaceN => Volatility::Immutable, - BuiltinScalarFunction::ArrayReplaceAll => Volatility::Immutable, BuiltinScalarFunction::Ascii => Volatility::Immutable, BuiltinScalarFunction::BitLength => Volatility::Immutable, BuiltinScalarFunction::Btrim => Volatility::Immutable, @@ -322,9 +311,6 @@ impl BuiltinScalarFunction { // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. match self { - BuiltinScalarFunction::ArrayReplace => Ok(input_expr_types[0].clone()), - BuiltinScalarFunction::ArrayReplaceN => Ok(input_expr_types[0].clone()), - BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::Ascii => Ok(Int32), BuiltinScalarFunction::BitLength => { utf8_to_int_type(&input_expr_types[0], "bit_length") @@ -477,11 +463,6 @@ impl BuiltinScalarFunction { // for now, the list is small, as we do not have many built-in functions. match self { - BuiltinScalarFunction::ArrayReplace => Signature::any(3, self.volatility()), - BuiltinScalarFunction::ArrayReplaceN => Signature::any(4, self.volatility()), - BuiltinScalarFunction::ArrayReplaceAll => { - Signature::any(3, self.volatility()) - } BuiltinScalarFunction::Concat | BuiltinScalarFunction::ConcatWithSeparator => { Signature::variadic(vec![Utf8], self.volatility()) @@ -779,15 +760,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Levenshtein => &["levenshtein"], BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"], BuiltinScalarFunction::FindInSet => &["find_in_set"], - - // hashing functions - BuiltinScalarFunction::ArrayReplace => &["array_replace", "list_replace"], - BuiltinScalarFunction::ArrayReplaceN => { - &["array_replace_n", "list_replace_n"] - } - BuiltinScalarFunction::ArrayReplaceAll => { - &["array_replace_all", "list_replace_all"] - } BuiltinScalarFunction::OverLay => &["overlay"], } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index c5ad2a9b3ce4..b76164a1c83c 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -584,25 +584,6 @@ scalar_expr!( scalar_expr!(Uuid, uuid, , "returns uuid v4 as a string value"); scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); -scalar_expr!( - ArrayReplace, - array_replace, - array from to, - "replaces the first occurrence of the specified element with another specified element." -); -scalar_expr!( - ArrayReplaceN, - array_replace_n, - array from to max, - "replaces the first `max` occurrences of the specified element with another specified element." -); -scalar_expr!( - ArrayReplaceAll, - array_replace_all, - array from to, - "replaces all occurrences of the specified element with another specified element." -); - // string functions scalar_expr!(Ascii, ascii, chr, "ASCII code value of the character"); scalar_expr!( @@ -1145,11 +1126,6 @@ mod test { test_scalar_expr!(Translate, translate, string, from, to); test_scalar_expr!(Trim, trim, string); test_scalar_expr!(Upper, upper, string); - - test_scalar_expr!(ArrayReplace, array_replace, array, from, to); - test_scalar_expr!(ArrayReplaceN, array_replace_n, array, from, to, max); - test_scalar_expr!(ArrayReplaceAll, array_replace_all, array, from, to); - test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len); test_nary_scalar_expr!(OverLay, overlay, string, characters, position); test_scalar_expr!(Levenshtein, levenshtein, string1, string2); diff --git a/datafusion/functions-array/Cargo.toml b/datafusion/functions-array/Cargo.toml index 99239ffb3bdc..80c0e5e18768 100644 --- a/datafusion/functions-array/Cargo.toml +++ b/datafusion/functions-array/Cargo.toml @@ -40,6 +40,7 @@ path = "src/lib.rs" arrow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } +arrow-ord = { workspace = true } arrow-schema = { workspace = true } datafusion-common = { workspace = true } datafusion-execution = { workspace = true } diff --git a/datafusion/functions-array/src/core.rs b/datafusion/functions-array/src/core.rs index 4c84b7018c99..fdd127cc3f32 100644 --- a/datafusion/functions-array/src/core.rs +++ b/datafusion/functions-array/src/core.rs @@ -96,7 +96,7 @@ impl ScalarUDFImpl for MakeArray { } } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(make_array_inner)(args) } diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-array/src/lib.rs index 2c19dfad6222..fb16acdef2bd 100644 --- a/datafusion/functions-array/src/lib.rs +++ b/datafusion/functions-array/src/lib.rs @@ -36,6 +36,7 @@ mod extract; mod kernels; mod position; mod remove; +mod replace; mod rewrite; mod set_ops; mod udf; @@ -66,6 +67,9 @@ pub mod expr_fn { pub use super::remove::array_remove; pub use super::remove::array_remove_all; pub use super::remove::array_remove_n; + pub use super::replace::array_replace; + pub use super::replace::array_replace_all; + pub use super::replace::array_replace_n; pub use super::set_ops::array_distinct; pub use super::set_ops::array_intersect; pub use super::set_ops::array_union; @@ -120,8 +124,11 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { position::array_position_udf(), position::array_positions_udf(), remove::array_remove_udf(), - remove::array_remove_n_udf(), remove::array_remove_all_udf(), + remove::array_remove_n_udf(), + replace::array_replace_n_udf(), + replace::array_replace_all_udf(), + replace::array_replace_udf(), ]; functions.into_iter().try_for_each(|udf| { let existing_udf = registry.register_udf(udf)?; diff --git a/datafusion/functions-array/src/position.rs b/datafusion/functions-array/src/position.rs index 4988e0ded106..627cf3cb0cf0 100644 --- a/datafusion/functions-array/src/position.rs +++ b/datafusion/functions-array/src/position.rs @@ -27,8 +27,7 @@ use std::sync::Arc; use arrow_array::types::UInt64Type; use arrow_array::{ - Array, ArrayRef, BooleanArray, GenericListArray, ListArray, OffsetSizeTrait, Scalar, - UInt32Array, UInt64Array, + Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait, UInt64Array, }; use datafusion_common::cast::{ as_generic_list_array, as_int64_array, as_large_list_array, as_list_array, @@ -36,6 +35,8 @@ use datafusion_common::cast::{ use datafusion_common::{exec_err, internal_err}; use itertools::Itertools; +use crate::utils::compare_element_to_list; + make_udf_function!( ArrayPosition, array_position, @@ -173,107 +174,6 @@ fn generic_position( Ok(Arc::new(UInt64Array::from(data))) } -/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array. -/// -/// # Arguments -/// -/// * `list_array_row` - A reference to a trait object implementing the Arrow `Array` trait. It represents the list array for which the equality or inequality will be compared. -/// -/// * `element_array` - A reference to a trait object implementing the Arrow `Array` trait. It represents the array with which each element in the `list_array_row` will be compared. -/// -/// * `row_index` - The index of the row in the `element_array` and `list_array` to use for the comparison. -/// -/// * `eq` - A boolean flag. If `true`, the function computes equality; if `false`, it computes inequality. -/// -/// # Returns -/// -/// Returns a `Result` representing the comparison results. The result may contain an error if there are issues with the computation. -/// -/// # Example -/// -/// ```text -/// compare_element_to_list( -/// [1, 2, 3], [1, 2, 3], 0, true => [true, false, false] -/// [1, 2, 3, 3, 2, 1], [1, 2, 3], 1, true => [false, true, false, false, true, false] -/// -/// [[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 0, true => [true, false, false] -/// [[1, 2, 3], [2, 3, 4], [2, 3, 4]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 1, false => [true, false, false] -/// ) -/// ``` -fn compare_element_to_list( - list_array_row: &dyn Array, - element_array: &dyn Array, - row_index: usize, - eq: bool, -) -> datafusion_common::Result { - if list_array_row.data_type() != element_array.data_type() { - return exec_err!( - "compare_element_to_list received incompatible types: '{:?}' and '{:?}'.", - list_array_row.data_type(), - element_array.data_type() - ); - } - - let indices = UInt32Array::from(vec![row_index as u32]); - let element_array_row = arrow::compute::take(element_array, &indices, None)?; - - // Compute all positions in list_row_array (that is itself an - // array) that are equal to `from_array_row` - let res = match element_array_row.data_type() { - // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop - DataType::List(_) => { - // compare each element of the from array - let element_array_row_inner = as_list_array(&element_array_row)?.value(0); - let list_array_row_inner = as_list_array(list_array_row)?; - - list_array_row_inner - .iter() - // compare element by element the current row of list_array - .map(|row| { - row.map(|row| { - if eq { - row.eq(&element_array_row_inner) - } else { - row.ne(&element_array_row_inner) - } - }) - }) - .collect::() - } - DataType::LargeList(_) => { - // compare each element of the from array - let element_array_row_inner = - as_large_list_array(&element_array_row)?.value(0); - let list_array_row_inner = as_large_list_array(list_array_row)?; - - list_array_row_inner - .iter() - // compare element by element the current row of list_array - .map(|row| { - row.map(|row| { - if eq { - row.eq(&element_array_row_inner) - } else { - row.ne(&element_array_row_inner) - } - }) - }) - .collect::() - } - _ => { - let element_arr = Scalar::new(element_array_row); - // use not_distinct so we can compare NULL - if eq { - arrow::compute::kernels::cmp::not_distinct(&list_array_row, &element_arr)? - } else { - arrow::compute::kernels::cmp::distinct(&list_array_row, &element_arr)? - } - } - }; - - Ok(res) -} - make_udf_function!( ArrayPositions, array_positions, diff --git a/datafusion/functions-array/src/replace.rs b/datafusion/functions-array/src/replace.rs new file mode 100644 index 000000000000..8ff65d315431 --- /dev/null +++ b/datafusion/functions-array/src/replace.rs @@ -0,0 +1,362 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array functions. + +use arrow::array::{ + Array, ArrayRef, AsArray, Capacities, MutableArrayData, OffsetSizeTrait, +}; +use arrow::datatypes::DataType; + +use arrow_array::GenericListArray; +use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer}; +use arrow_schema::Field; +use datafusion_common::cast::as_int64_array; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::Expr; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::compare_element_to_list; +use crate::utils::make_scalar_function; + +use std::any::Any; +use std::sync::Arc; + +// Create static instances of ScalarUDFs for each function +make_udf_function!(ArrayReplace, + array_replace, + array from to, + "replaces the first occurrence of the specified element with another specified element.", + array_replace_udf +); +make_udf_function!(ArrayReplaceN, + array_replace_n, + array from to max, + "replaces the first `max` occurrences of the specified element with another specified element.", + array_replace_n_udf +); +make_udf_function!(ArrayReplaceAll, + array_replace_all, + array from to, + "replaces all occurrences of the specified element with another specified element.", + array_replace_all_udf +); + +#[derive(Debug)] +pub(super) struct ArrayReplace { + signature: Signature, + aliases: Vec, +} + +impl ArrayReplace { + pub fn new() -> Self { + Self { + signature: Signature::any(3, Volatility::Immutable), + aliases: vec![String::from("array_replace"), String::from("list_replace")], + } + } +} + +impl ScalarUDFImpl for ArrayReplace { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "array_replace" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, args: &[DataType]) -> datafusion_common::Result { + Ok(args[0].clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + make_scalar_function(array_replace_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +#[derive(Debug)] +pub(super) struct ArrayReplaceN { + signature: Signature, + aliases: Vec, +} + +impl ArrayReplaceN { + pub fn new() -> Self { + Self { + signature: Signature::any(4, Volatility::Immutable), + aliases: vec![ + String::from("array_replace_n"), + String::from("list_replace_n"), + ], + } + } +} + +impl ScalarUDFImpl for ArrayReplaceN { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "array_replace_n" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, args: &[DataType]) -> datafusion_common::Result { + Ok(args[0].clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + make_scalar_function(array_replace_n_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +#[derive(Debug)] +pub(super) struct ArrayReplaceAll { + signature: Signature, + aliases: Vec, +} + +impl ArrayReplaceAll { + pub fn new() -> Self { + Self { + signature: Signature::any(3, Volatility::Immutable), + aliases: vec![ + String::from("array_replace_all"), + String::from("list_replace_all"), + ], + } + } +} + +impl ScalarUDFImpl for ArrayReplaceAll { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "array_replace_all" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, args: &[DataType]) -> datafusion_common::Result { + Ok(args[0].clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + make_scalar_function(array_replace_all_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurences +/// of `from_array[i]`, `to_array[i]`. +/// +/// The type of each **element** in `list_array` must be the same as the type of +/// `from_array` and `to_array`. This function also handles nested arrays +/// (\[`ListArray`\] of \[`ListArray`\]s) +/// +/// For example, when called to replace a list array (where each element is a +/// list of int32s, the second and third argument are int32 arrays, and the +/// fourth argument is the number of occurrences to replace +/// +/// ```text +/// general_replace( +/// [1, 2, 3, 2], 2, 10, 1 ==> [1, 10, 3, 2] (only the first 2 is replaced) +/// [4, 5, 6, 5], 5, 20, 2 ==> [4, 20, 6, 20] (both 5s are replaced) +/// ) +/// ``` +fn general_replace( + list_array: &GenericListArray, + from_array: &ArrayRef, + to_array: &ArrayRef, + arr_n: Vec, +) -> Result { + // Build up the offsets for the final output array + let mut offsets: Vec = vec![O::usize_as(0)]; + let values = list_array.values(); + let original_data = values.to_data(); + let to_data = to_array.to_data(); + let capacity = Capacities::Array(original_data.len()); + + // First array is the original array, second array is the element to replace with. + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data, &to_data], + false, + capacity, + ); + + let mut valid = BooleanBufferBuilder::new(list_array.len()); + + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + if list_array.is_null(row_index) { + offsets.push(offsets[row_index]); + valid.append(false); + continue; + } + + let start = offset_window[0]; + let end = offset_window[1]; + + let list_array_row = list_array.value(row_index); + + // Compute all positions in list_row_array (that is itself an + // array) that are equal to `from_array_row` + let eq_array = + compare_element_to_list(&list_array_row, &from_array, row_index, true)?; + + let original_idx = O::usize_as(0); + let replace_idx = O::usize_as(1); + let n = arr_n[row_index]; + let mut counter = 0; + + // All elements are false, no need to replace, just copy original data + if eq_array.false_count() == eq_array.len() { + mutable.extend( + original_idx.to_usize().unwrap(), + start.to_usize().unwrap(), + end.to_usize().unwrap(), + ); + offsets.push(offsets[row_index] + (end - start)); + valid.append(true); + continue; + } + + for (i, to_replace) in eq_array.iter().enumerate() { + let i = O::usize_as(i); + if let Some(true) = to_replace { + mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1); + counter += 1; + if counter == n { + // copy original data for any matches past n + mutable.extend( + original_idx.to_usize().unwrap(), + (start + i).to_usize().unwrap() + 1, + end.to_usize().unwrap(), + ); + break; + } + } else { + // copy original data for false / null matches + mutable.extend( + original_idx.to_usize().unwrap(), + (start + i).to_usize().unwrap(), + (start + i).to_usize().unwrap() + 1, + ); + } + } + + offsets.push(offsets[row_index] + (end - start)); + valid.append(true); + } + + let data = mutable.freeze(); + + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new("item", list_array.value_type(), true)), + OffsetBuffer::::new(offsets.into()), + arrow_array::make_array(data), + Some(NullBuffer::new(valid.finish())), + )?)) +} + +pub(crate) fn array_replace_inner(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return exec_err!("array_replace expects three arguments"); + } + + // replace at most one occurence for each element + let arr_n = vec![1; args[0].len()]; + let array = &args[0]; + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + array_type => exec_err!("array_replace does not support type '{array_type:?}'."), + } +} + +pub(crate) fn array_replace_n_inner(args: &[ArrayRef]) -> Result { + if args.len() != 4 { + return exec_err!("array_replace_n expects four arguments"); + } + + // replace the specified number of occurences + let arr_n = as_int64_array(&args[3])?.values().to_vec(); + let array = &args[0]; + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + array_type => { + exec_err!("array_replace_n does not support type '{array_type:?}'.") + } + } +} + +pub(crate) fn array_replace_all_inner(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return exec_err!("array_replace_all expects three arguments"); + } + + // replace all occurrences (up to "i64::MAX") + let arr_n = vec![i64::MAX; args[0].len()]; + let array = &args[0]; + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + array_type => { + exec_err!("array_replace_all does not support type '{array_type:?}'.") + } + } +} diff --git a/datafusion/functions-array/src/utils.rs b/datafusion/functions-array/src/utils.rs index ad613163c6af..9589cb05fe9b 100644 --- a/datafusion/functions-array/src/utils.rs +++ b/datafusion/functions-array/src/utils.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use arrow::{array::ArrayRef, datatypes::DataType}; + use arrow_array::{ Array, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar, UInt32Array, }; @@ -27,6 +28,7 @@ use arrow_buffer::OffsetBuffer; use arrow_schema::Field; use datafusion_common::cast::{as_large_list_array, as_list_array}; use datafusion_common::{exec_err, plan_err, Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; pub(crate) fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> { @@ -202,9 +204,9 @@ pub(crate) fn compare_element_to_list( let element_arr = Scalar::new(element_array_row); // use not_distinct so we can compare NULL if eq { - arrow::compute::kernels::cmp::not_distinct(&list_array_row, &element_arr)? + arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)? } else { - arrow::compute::kernels::cmp::distinct(&list_array_row, &element_arr)? + arrow_ord::cmp::distinct(&list_array_row, &element_arr)? } } }; diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs deleted file mode 100644 index c3c0f4c82282..000000000000 --- a/datafusion/physical-expr/src/array_expressions.rs +++ /dev/null @@ -1,423 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Array expressions - -use std::sync::Arc; - -use arrow::array::*; -use arrow::buffer::OffsetBuffer; -use arrow::datatypes::{DataType, Field}; -use arrow_buffer::NullBuffer; - -use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; -use datafusion_common::utils::array_into_list_array; -use datafusion_common::{exec_err, plan_err, Result}; - -/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array. -/// -/// # Arguments -/// -/// * `list_array_row` - A reference to a trait object implementing the Arrow `Array` trait. It represents the list array for which the equality or inequality will be compared. -/// -/// * `element_array` - A reference to a trait object implementing the Arrow `Array` trait. It represents the array with which each element in the `list_array_row` will be compared. -/// -/// * `row_index` - The index of the row in the `element_array` and `list_array` to use for the comparison. -/// -/// * `eq` - A boolean flag. If `true`, the function computes equality; if `false`, it computes inequality. -/// -/// # Returns -/// -/// Returns a `Result` representing the comparison results. The result may contain an error if there are issues with the computation. -/// -/// # Example -/// -/// ```text -/// compare_element_to_list( -/// [1, 2, 3], [1, 2, 3], 0, true => [true, false, false] -/// [1, 2, 3, 3, 2, 1], [1, 2, 3], 1, true => [false, true, false, false, true, false] -/// -/// [[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 0, true => [true, false, false] -/// [[1, 2, 3], [2, 3, 4], [2, 3, 4]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 1, false => [true, false, false] -/// ) -/// ``` -fn compare_element_to_list( - list_array_row: &dyn Array, - element_array: &dyn Array, - row_index: usize, - eq: bool, -) -> Result { - if list_array_row.data_type() != element_array.data_type() { - return exec_err!( - "compare_element_to_list received incompatible types: '{:?}' and '{:?}'.", - list_array_row.data_type(), - element_array.data_type() - ); - } - - let indices = UInt32Array::from(vec![row_index as u32]); - let element_array_row = arrow::compute::take(element_array, &indices, None)?; - - // Compute all positions in list_row_array (that is itself an - // array) that are equal to `from_array_row` - let res = match element_array_row.data_type() { - // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop - DataType::List(_) => { - // compare each element of the from array - let element_array_row_inner = as_list_array(&element_array_row)?.value(0); - let list_array_row_inner = as_list_array(list_array_row)?; - - list_array_row_inner - .iter() - // compare element by element the current row of list_array - .map(|row| { - row.map(|row| { - if eq { - row.eq(&element_array_row_inner) - } else { - row.ne(&element_array_row_inner) - } - }) - }) - .collect::() - } - DataType::LargeList(_) => { - // compare each element of the from array - let element_array_row_inner = - as_large_list_array(&element_array_row)?.value(0); - let list_array_row_inner = as_large_list_array(list_array_row)?; - - list_array_row_inner - .iter() - // compare element by element the current row of list_array - .map(|row| { - row.map(|row| { - if eq { - row.eq(&element_array_row_inner) - } else { - row.ne(&element_array_row_inner) - } - }) - }) - .collect::() - } - _ => { - let element_arr = Scalar::new(element_array_row); - // use not_distinct so we can compare NULL - if eq { - arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)? - } else { - arrow_ord::cmp::distinct(&list_array_row, &element_arr)? - } - } - }; - - Ok(res) -} - -/// Convert one or more [`ArrayRef`] of the same type into a -/// `ListArray` or 'LargeListArray' depending on the offset size. -/// -/// # Example (non nested) -/// -/// Calling `array(col1, col2)` where col1 and col2 are non nested -/// would return a single new `ListArray`, where each row was a list -/// of 2 elements: -/// -/// ```text -/// ┌─────────┐ ┌─────────┐ ┌──────────────┐ -/// │ ┌─────┐ │ │ ┌─────┐ │ │ ┌──────────┐ │ -/// │ │ A │ │ │ │ X │ │ │ │ [A, X] │ │ -/// │ ├─────┤ │ │ ├─────┤ │ │ ├──────────┤ │ -/// │ │NULL │ │ │ │ Y │ │──────────▶│ │[NULL, Y] │ │ -/// │ ├─────┤ │ │ ├─────┤ │ │ ├──────────┤ │ -/// │ │ C │ │ │ │ Z │ │ │ │ [C, Z] │ │ -/// │ └─────┘ │ │ └─────┘ │ │ └──────────┘ │ -/// └─────────┘ └─────────┘ └──────────────┘ -/// col1 col2 output -/// ``` -/// -/// # Example (nested) -/// -/// Calling `array(col1, col2)` where col1 and col2 are lists -/// would return a single new `ListArray`, where each row was a list -/// of the corresponding elements of col1 and col2. -/// -/// ``` text -/// ┌──────────────┐ ┌──────────────┐ ┌─────────────────────────────┐ -/// │ ┌──────────┐ │ │ ┌──────────┐ │ │ ┌────────────────────────┐ │ -/// │ │ [A, X] │ │ │ │ [] │ │ │ │ [[A, X], []] │ │ -/// │ ├──────────┤ │ │ ├──────────┤ │ │ ├────────────────────────┤ │ -/// │ │[NULL, Y] │ │ │ │[Q, R, S] │ │───────▶│ │ [[NULL, Y], [Q, R, S]] │ │ -/// │ ├──────────┤ │ │ ├──────────┤ │ │ ├────────────────────────│ │ -/// │ │ [C, Z] │ │ │ │ NULL │ │ │ │ [[C, Z], NULL] │ │ -/// │ └──────────┘ │ │ └──────────┘ │ │ └────────────────────────┘ │ -/// └──────────────┘ └──────────────┘ └─────────────────────────────┘ -/// col1 col2 output -/// ``` -fn array_array( - args: &[ArrayRef], - data_type: DataType, -) -> Result { - // do not accept 0 arguments. - if args.is_empty() { - return plan_err!("Array requires at least one argument"); - } - - let mut data = vec![]; - let mut total_len = 0; - for arg in args { - let arg_data = if arg.as_any().is::() { - ArrayData::new_empty(&data_type) - } else { - arg.to_data() - }; - total_len += arg_data.len(); - data.push(arg_data); - } - - let mut offsets: Vec = Vec::with_capacity(total_len); - offsets.push(O::usize_as(0)); - - let capacity = Capacities::Array(total_len); - let data_ref = data.iter().collect::>(); - let mut mutable = MutableArrayData::with_capacities(data_ref, true, capacity); - - let num_rows = args[0].len(); - for row_idx in 0..num_rows { - for (arr_idx, arg) in args.iter().enumerate() { - if !arg.as_any().is::() - && !arg.is_null(row_idx) - && arg.is_valid(row_idx) - { - mutable.extend(arr_idx, row_idx, row_idx + 1); - } else { - mutable.extend_nulls(1); - } - } - offsets.push(O::usize_as(mutable.len())); - } - let data = mutable.freeze(); - - Ok(Arc::new(GenericListArray::::try_new( - Arc::new(Field::new("item", data_type, true)), - OffsetBuffer::new(offsets.into()), - arrow_array::make_array(data), - None, - )?)) -} - -/// `make_array` SQL function -pub fn make_array(arrays: &[ArrayRef]) -> Result { - let mut data_type = DataType::Null; - for arg in arrays { - let arg_data_type = arg.data_type(); - if !arg_data_type.equals_datatype(&DataType::Null) { - data_type = arg_data_type.clone(); - break; - } - } - - match data_type { - // Either an empty array or all nulls: - DataType::Null => { - let array = - new_null_array(&DataType::Null, arrays.iter().map(|a| a.len()).sum()); - Ok(Arc::new(array_into_list_array(array))) - } - DataType::LargeList(..) => array_array::(arrays, data_type), - _ => array_array::(arrays, data_type), - } -} - -/// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurences -/// of `from_array[i]`, `to_array[i]`. -/// -/// The type of each **element** in `list_array` must be the same as the type of -/// `from_array` and `to_array`. This function also handles nested arrays -/// ([`ListArray`] of [`ListArray`]s) -/// -/// For example, when called to replace a list array (where each element is a -/// list of int32s, the second and third argument are int32 arrays, and the -/// fourth argument is the number of occurrences to replace -/// -/// ```text -/// general_replace( -/// [1, 2, 3, 2], 2, 10, 1 ==> [1, 10, 3, 2] (only the first 2 is replaced) -/// [4, 5, 6, 5], 5, 20, 2 ==> [4, 20, 6, 20] (both 5s are replaced) -/// ) -/// ``` -fn general_replace( - list_array: &GenericListArray, - from_array: &ArrayRef, - to_array: &ArrayRef, - arr_n: Vec, -) -> Result { - // Build up the offsets for the final output array - let mut offsets: Vec = vec![O::usize_as(0)]; - let values = list_array.values(); - let original_data = values.to_data(); - let to_data = to_array.to_data(); - let capacity = Capacities::Array(original_data.len()); - - // First array is the original array, second array is the element to replace with. - let mut mutable = MutableArrayData::with_capacities( - vec![&original_data, &to_data], - false, - capacity, - ); - - let mut valid = BooleanBufferBuilder::new(list_array.len()); - - for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { - if list_array.is_null(row_index) { - offsets.push(offsets[row_index]); - valid.append(false); - continue; - } - - let start = offset_window[0]; - let end = offset_window[1]; - - let list_array_row = list_array.value(row_index); - - // Compute all positions in list_row_array (that is itself an - // array) that are equal to `from_array_row` - let eq_array = - compare_element_to_list(&list_array_row, &from_array, row_index, true)?; - - let original_idx = O::usize_as(0); - let replace_idx = O::usize_as(1); - let n = arr_n[row_index]; - let mut counter = 0; - - // All elements are false, no need to replace, just copy original data - if eq_array.false_count() == eq_array.len() { - mutable.extend( - original_idx.to_usize().unwrap(), - start.to_usize().unwrap(), - end.to_usize().unwrap(), - ); - offsets.push(offsets[row_index] + (end - start)); - valid.append(true); - continue; - } - - for (i, to_replace) in eq_array.iter().enumerate() { - let i = O::usize_as(i); - if let Some(true) = to_replace { - mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1); - counter += 1; - if counter == n { - // copy original data for any matches past n - mutable.extend( - original_idx.to_usize().unwrap(), - (start + i).to_usize().unwrap() + 1, - end.to_usize().unwrap(), - ); - break; - } - } else { - // copy original data for false / null matches - mutable.extend( - original_idx.to_usize().unwrap(), - (start + i).to_usize().unwrap(), - (start + i).to_usize().unwrap() + 1, - ); - } - } - - offsets.push(offsets[row_index] + (end - start)); - valid.append(true); - } - - let data = mutable.freeze(); - - Ok(Arc::new(GenericListArray::::try_new( - Arc::new(Field::new("item", list_array.value_type(), true)), - OffsetBuffer::::new(offsets.into()), - arrow_array::make_array(data), - Some(NullBuffer::new(valid.finish())), - )?)) -} - -pub fn array_replace(args: &[ArrayRef]) -> Result { - if args.len() != 3 { - return exec_err!("array_replace expects three arguments"); - } - - // replace at most one occurence for each element - let arr_n = vec![1; args[0].len()]; - let array = &args[0]; - match array.data_type() { - DataType::List(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) - } - DataType::LargeList(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) - } - array_type => exec_err!("array_replace does not support type '{array_type:?}'."), - } -} - -pub fn array_replace_n(args: &[ArrayRef]) -> Result { - if args.len() != 4 { - return exec_err!("array_replace_n expects four arguments"); - } - - // replace the specified number of occurences - let arr_n = as_int64_array(&args[3])?.values().to_vec(); - let array = &args[0]; - match array.data_type() { - DataType::List(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) - } - DataType::LargeList(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) - } - array_type => { - exec_err!("array_replace_n does not support type '{array_type:?}'.") - } - } -} - -pub fn array_replace_all(args: &[ArrayRef]) -> Result { - if args.len() != 3 { - return exec_err!("array_replace_all expects three arguments"); - } - - // replace all occurrences (up to "i64::MAX") - let arr_n = vec![i64::MAX; args[0].len()]; - let array = &args[0]; - match array.data_type() { - DataType::List(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) - } - DataType::LargeList(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) - } - array_type => { - exec_err!("array_replace_all does not support type '{array_type:?}'.") - } - } -} diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 994c17309ec0..c6c185e002f0 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -32,8 +32,8 @@ use crate::sort_properties::SortProperties; use crate::{ - array_expressions, conditional_expressions, math_expressions, string_expressions, - PhysicalExpr, ScalarFunctionExpr, + conditional_expressions, math_expressions, string_expressions, PhysicalExpr, + ScalarFunctionExpr, }; use arrow::{ array::ArrayRef, @@ -253,18 +253,6 @@ pub fn create_physical_fun( BuiltinScalarFunction::Cot => { Arc::new(|args| make_scalar_function_inner(math_expressions::cot)(args)) } - - // array functions - BuiltinScalarFunction::ArrayReplace => Arc::new(|args| { - make_scalar_function_inner(array_expressions::array_replace)(args) - }), - BuiltinScalarFunction::ArrayReplaceN => Arc::new(|args| { - make_scalar_function_inner(array_expressions::array_replace_n)(args) - }), - BuiltinScalarFunction::ArrayReplaceAll => Arc::new(|args| { - make_scalar_function_inner(array_expressions::array_replace_all)(args) - }), - // string functions BuiltinScalarFunction::Ascii => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index e8b80ee4e1e6..1791a6ed60b2 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -17,7 +17,6 @@ pub mod aggregate; pub mod analysis; -pub mod array_expressions; pub mod binary_map; pub mod conditional_expressions; pub mod equivalence; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 597094758584..6879f70cd05c 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -637,7 +637,7 @@ enum ScalarFunction { // 93 was ArrayPositions // 94 was ArrayPrepend // 95 was ArrayRemove - ArrayReplace = 96; + // 96 was ArrayReplace // 97 was ArrayToString // 98 was Cardinality // 99 was ArrayElement @@ -647,9 +647,9 @@ enum ScalarFunction { // 105 was ArrayHasAny // 106 was ArrayHasAll // 107 was ArrayRemoveN - ArrayReplaceN = 108; + // 108 was ArrayReplaceN // 109 was ArrayRemoveAll - ArrayReplaceAll = 110; + // 110 was ArrayReplaceAll Nanvl = 111; // 112 was Flatten // 113 was IsNan diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index cb9633338e8f..75c135fd01b4 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22951,10 +22951,7 @@ impl serde::Serialize for ScalarFunction { Self::Factorial => "Factorial", Self::Lcm => "Lcm", Self::Gcd => "Gcd", - Self::ArrayReplace => "ArrayReplace", Self::Cot => "Cot", - Self::ArrayReplaceN => "ArrayReplaceN", - Self::ArrayReplaceAll => "ArrayReplaceAll", Self::Nanvl => "Nanvl", Self::Iszero => "Iszero", Self::OverLay => "OverLay", @@ -23032,10 +23029,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Factorial", "Lcm", "Gcd", - "ArrayReplace", "Cot", - "ArrayReplaceN", - "ArrayReplaceAll", "Nanvl", "Iszero", "OverLay", @@ -23142,10 +23136,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Factorial" => Ok(ScalarFunction::Factorial), "Lcm" => Ok(ScalarFunction::Lcm), "Gcd" => Ok(ScalarFunction::Gcd), - "ArrayReplace" => Ok(ScalarFunction::ArrayReplace), "Cot" => Ok(ScalarFunction::Cot), - "ArrayReplaceN" => Ok(ScalarFunction::ArrayReplaceN), - "ArrayReplaceAll" => Ok(ScalarFunction::ArrayReplaceAll), "Nanvl" => Ok(ScalarFunction::Nanvl), "Iszero" => Ok(ScalarFunction::Iszero), "OverLay" => Ok(ScalarFunction::OverLay), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index f5ef6c1f74f0..c9cc4a9b073b 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2910,7 +2910,7 @@ pub enum ScalarFunction { /// 93 was ArrayPositions /// 94 was ArrayPrepend /// 95 was ArrayRemove - ArrayReplace = 96, + /// 96 was ArrayReplace /// 97 was ArrayToString /// 98 was Cardinality /// 99 was ArrayElement @@ -2920,9 +2920,9 @@ pub enum ScalarFunction { /// 105 was ArrayHasAny /// 106 was ArrayHasAll /// 107 was ArrayRemoveN - ArrayReplaceN = 108, + /// 108 was ArrayReplaceN /// 109 was ArrayRemoveAll - ArrayReplaceAll = 110, + /// 110 was ArrayReplaceAll Nanvl = 111, /// 112 was Flatten /// 113 was IsNan @@ -3019,10 +3019,7 @@ impl ScalarFunction { ScalarFunction::Factorial => "Factorial", ScalarFunction::Lcm => "Lcm", ScalarFunction::Gcd => "Gcd", - ScalarFunction::ArrayReplace => "ArrayReplace", ScalarFunction::Cot => "Cot", - ScalarFunction::ArrayReplaceN => "ArrayReplaceN", - ScalarFunction::ArrayReplaceAll => "ArrayReplaceAll", ScalarFunction::Nanvl => "Nanvl", ScalarFunction::Iszero => "Iszero", ScalarFunction::OverLay => "OverLay", @@ -3094,10 +3091,7 @@ impl ScalarFunction { "Factorial" => Some(Self::Factorial), "Lcm" => Some(Self::Lcm), "Gcd" => Some(Self::Gcd), - "ArrayReplace" => Some(Self::ArrayReplace), "Cot" => Some(Self::Cot), - "ArrayReplaceN" => Some(Self::ArrayReplaceN), - "ArrayReplaceAll" => Some(Self::ArrayReplaceAll), "Nanvl" => Some(Self::Nanvl), "Iszero" => Some(Self::Iszero), "OverLay" => Some(Self::OverLay), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 3822b74bc18c..06aab16edd57 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -48,9 +48,9 @@ use datafusion_expr::expr::Unnest; use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - acosh, array_replace, array_replace_all, array_replace_n, ascii, asinh, atan, atan2, - atanh, bit_length, btrim, cbrt, ceil, character_length, chr, coalesce, concat_expr, - concat_ws_expr, cos, cosh, cot, degrees, ends_with, exp, + acosh, ascii, asinh, atan, atan2, atanh, bit_length, btrim, cbrt, ceil, + character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, + degrees, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, factorial, find_in_set, floor, gcd, initcap, iszero, lcm, left, levenshtein, ln, log, log10, log2, @@ -466,9 +466,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Trim => Self::Trim, ScalarFunction::Ltrim => Self::Ltrim, ScalarFunction::Rtrim => Self::Rtrim, - ScalarFunction::ArrayReplace => Self::ArrayReplace, - ScalarFunction::ArrayReplaceN => Self::ArrayReplaceN, - ScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll, ScalarFunction::Log2 => Self::Log2, ScalarFunction::Signum => Self::Signum, ScalarFunction::Ascii => Self::Ascii, @@ -1362,22 +1359,6 @@ pub fn parse_expr( ScalarFunction::Acosh => { Ok(acosh(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::ArrayReplace => Ok(array_replace( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - parse_expr(&args[2], registry, codec)?, - )), - ScalarFunction::ArrayReplaceN => Ok(array_replace_n( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - parse_expr(&args[2], registry, codec)?, - parse_expr(&args[3], registry, codec)?, - )), - ScalarFunction::ArrayReplaceAll => Ok(array_replace_all( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - parse_expr(&args[2], registry, codec)?, - )), ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Cbrt => Ok(cbrt(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry, codec)?)), diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 7a17d2a2b405..478f7c779552 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1453,9 +1453,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Trim => Self::Trim, BuiltinScalarFunction::Ltrim => Self::Ltrim, BuiltinScalarFunction::Rtrim => Self::Rtrim, - BuiltinScalarFunction::ArrayReplace => Self::ArrayReplace, - BuiltinScalarFunction::ArrayReplaceN => Self::ArrayReplaceN, - BuiltinScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll, BuiltinScalarFunction::Log2 => Self::Log2, BuiltinScalarFunction::Signum => Self::Signum, BuiltinScalarFunction::Ascii => Self::Ascii, diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 479f80fbdddf..93de560dbee5 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -605,6 +605,14 @@ async fn roundtrip_expr_api() -> Result<()> { make_array(vec![lit(3), lit(3), lit(2), lit(3), lit(1)]), lit(3), ), + array_replace(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)), + array_replace_n( + make_array(vec![lit(1), lit(2), lit(3)]), + lit(2), + lit(4), + lit(1), + ), + array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)), ]; // ensure expressions created with the expr api can be round tripped From 2c7cf5f41af8f74feec8c8581732ef0fb5003008 Mon Sep 17 00:00:00 2001 From: InventiveCoder <163831412+InventiveCoder@users.noreply.github.com> Date: Tue, 19 Mar 2024 02:21:46 +0800 Subject: [PATCH 03/35] chore: remove repetitive words (#9673) Signed-off-by: InventiveCoder --- datafusion/common/src/stats.rs | 2 +- datafusion/core/src/datasource/file_format/mod.rs | 2 +- .../src/physical_optimizer/enforce_distribution.rs | 2 +- .../src/physical_optimizer/output_requirements.rs | 2 +- .../src/physical_optimizer/projection_pushdown.rs | 14 +++++++------- datafusion/physical-expr/src/binary_map.rs | 2 +- .../physical-plan/src/aggregates/order/mod.rs | 2 +- datafusion/sql/src/expr/arrow_cast.rs | 2 +- .../sqllogictest/test_files/create_function.slt | 2 +- datafusion/sqllogictest/test_files/limit.slt | 2 +- dev/changelog/13.0.0.md | 2 +- dev/changelog/7.0.0.md | 2 +- docs/source/contributor-guide/communication.md | 2 +- docs/source/library-user-guide/adding-udfs.md | 2 +- 14 files changed, 20 insertions(+), 20 deletions(-) diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index a10e05a55c64..6cefef8d0eb5 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -221,7 +221,7 @@ pub struct Statistics { /// Total bytes of the table rows. pub total_byte_size: Precision, /// Statistics on a column level. It contains a [`ColumnStatistics`] for - /// each field in the schema of the the table to which the [`Statistics`] refer. + /// each field in the schema of the table to which the [`Statistics`] refer. pub column_statistics: Vec, } diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index 72dc289d4b64..5ee0f7186703 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -49,7 +49,7 @@ use object_store::{ObjectMeta, ObjectStore}; /// This trait abstracts all the file format specific implementations /// from the [`TableProvider`]. This helps code re-utilization across -/// providers that support the the same file formats. +/// providers that support the same file formats. /// /// [`TableProvider`]: crate::datasource::provider::TableProvider #[async_trait] diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index 54fe6e8406fd..0740a8d2cdbc 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -392,7 +392,7 @@ fn adjust_input_keys_ordering( let expr = proj.expr(); // For Projection, we need to transform the requirements to the columns before the Projection // And then to push down the requirements - // Construct a mapping from new name to the the orginal Column + // Construct a mapping from new name to the orginal Column let new_required = map_columns_before_projection(&requirements.data, expr); if new_required.len() == requirements.data.len() { requirements.children[0].data = new_required; diff --git a/datafusion/core/src/physical_optimizer/output_requirements.rs b/datafusion/core/src/physical_optimizer/output_requirements.rs index bd71b3e8ed80..bf010a5e39d8 100644 --- a/datafusion/core/src/physical_optimizer/output_requirements.rs +++ b/datafusion/core/src/physical_optimizer/output_requirements.rs @@ -216,7 +216,7 @@ impl PhysicalOptimizerRule for OutputRequirements { } } -/// This functions adds ancillary `OutputRequirementExec` to the the physical plan, so that +/// This functions adds ancillary `OutputRequirementExec` to the physical plan, so that /// global requirements are not lost during optimization. fn require_top_ordering(plan: Arc) -> Result> { let (new_plan, is_changed) = require_top_ordering_helper(plan)?; diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index e8f3bf01ecaa..ab5611597472 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -322,7 +322,7 @@ fn try_swapping_with_output_req( projection: &ProjectionExec, output_req: &OutputRequirementExec, ) -> Result>> { - // If the projection does not narrow the the schema, we should not try to push it down: + // If the projection does not narrow the schema, we should not try to push it down: if projection.expr().len() >= projection.input().schema().fields().len() { return Ok(None); } @@ -372,7 +372,7 @@ fn try_swapping_with_output_req( fn try_swapping_with_coalesce_partitions( projection: &ProjectionExec, ) -> Result>> { - // If the projection does not narrow the the schema, we should not try to push it down: + // If the projection does not narrow the schema, we should not try to push it down: if projection.expr().len() >= projection.input().schema().fields().len() { return Ok(None); } @@ -387,7 +387,7 @@ fn try_swapping_with_filter( projection: &ProjectionExec, filter: &FilterExec, ) -> Result>> { - // If the projection does not narrow the the schema, we should not try to push it down: + // If the projection does not narrow the schema, we should not try to push it down: if projection.expr().len() >= projection.input().schema().fields().len() { return Ok(None); } @@ -412,7 +412,7 @@ fn try_swapping_with_repartition( projection: &ProjectionExec, repartition: &RepartitionExec, ) -> Result>> { - // If the projection does not narrow the the schema, we should not try to push it down. + // If the projection does not narrow the schema, we should not try to push it down. if projection.expr().len() >= projection.input().schema().fields().len() { return Ok(None); } @@ -454,7 +454,7 @@ fn try_swapping_with_sort( projection: &ProjectionExec, sort: &SortExec, ) -> Result>> { - // If the projection does not narrow the the schema, we should not try to push it down. + // If the projection does not narrow the schema, we should not try to push it down. if projection.expr().len() >= projection.input().schema().fields().len() { return Ok(None); } @@ -1082,7 +1082,7 @@ fn join_table_borders( (far_right_left_col_ind, far_left_right_col_ind) } -/// Tries to update the equi-join `Column`'s of a join as if the the input of +/// Tries to update the equi-join `Column`'s of a join as if the input of /// the join was replaced by a projection. fn update_join_on( proj_left_exprs: &[(Column, String)], @@ -1152,7 +1152,7 @@ fn new_columns_for_join_on( (new_columns.len() == hash_join_on.len()).then_some(new_columns) } -/// Tries to update the column indices of a [`JoinFilter`] as if the the input of +/// Tries to update the column indices of a [`JoinFilter`] as if the input of /// the join was replaced by a projection. fn update_join_filter( projection_left_exprs: &[(Column, String)], diff --git a/datafusion/physical-expr/src/binary_map.rs b/datafusion/physical-expr/src/binary_map.rs index b661f0a74148..6c3a452a8611 100644 --- a/datafusion/physical-expr/src/binary_map.rs +++ b/datafusion/physical-expr/src/binary_map.rs @@ -280,7 +280,7 @@ where /// # Returns /// /// The payload value for the entry, either the existing value or - /// the the newly inserted value + /// the newly inserted value /// /// # Safety: /// diff --git a/datafusion/physical-plan/src/aggregates/order/mod.rs b/datafusion/physical-plan/src/aggregates/order/mod.rs index 4f1914b12c96..556103e1e222 100644 --- a/datafusion/physical-plan/src/aggregates/order/mod.rs +++ b/datafusion/physical-plan/src/aggregates/order/mod.rs @@ -40,7 +40,7 @@ pub(crate) enum GroupOrdering { } impl GroupOrdering { - /// Create a `GroupOrdering` for the the specified ordering + /// Create a `GroupOrdering` for the specified ordering pub fn try_new( input_schema: &Schema, mode: &InputOrderMode, diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/sql/src/expr/arrow_cast.rs index 9a0d61f41c01..a75cdf9e3c6b 100644 --- a/datafusion/sql/src/expr/arrow_cast.rs +++ b/datafusion/sql/src/expr/arrow_cast.rs @@ -76,7 +76,7 @@ pub fn create_arrow_cast(mut args: Vec, schema: &DFSchema) -> Result /// Parses `str` into a `DataType`. /// -/// `parse_data_type` is the the reverse of [`DataType`]'s `Display` +/// `parse_data_type` is the reverse of [`DataType`]'s `Display` /// impl, and maintains the invariant that /// `parse_data_type(data_type.to_string()) == data_type` /// diff --git a/datafusion/sqllogictest/test_files/create_function.slt b/datafusion/sqllogictest/test_files/create_function.slt index baa40ac64afc..4f0c53c36ca1 100644 --- a/datafusion/sqllogictest/test_files/create_function.slt +++ b/datafusion/sqllogictest/test_files/create_function.slt @@ -47,7 +47,7 @@ select abs(-1); statement ok DROP FUNCTION abs; -# now the the query errors +# now the query errors query error Invalid function 'abs'. select abs(-1); diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt index 92093ba13eba..0d98c41d0028 100644 --- a/datafusion/sqllogictest/test_files/limit.slt +++ b/datafusion/sqllogictest/test_files/limit.slt @@ -320,7 +320,7 @@ SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11); 0 # The aggregate does not need to be computed because the input statistics are exact and -# the number of rows is less than or equal to the the "fetch+skip" value (LIMIT+OFFSET). +# the number of rows is less than or equal to the "fetch+skip" value (LIMIT+OFFSET). query TT EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8); ---- diff --git a/dev/changelog/13.0.0.md b/dev/changelog/13.0.0.md index 0f35903e2600..14b42a052ef9 100644 --- a/dev/changelog/13.0.0.md +++ b/dev/changelog/13.0.0.md @@ -87,7 +87,7 @@ - Optimizer rule 'projection_push_down' failed due to unexpected error: Error during planning: Aggregate schema has wrong number of fields. Expected 3 got 8 [\#3704](https://github.com/apache/arrow-datafusion/issues/3704) - Optimizer regressions in `unwrap_cast_in_comparison` [\#3690](https://github.com/apache/arrow-datafusion/issues/3690) - Internal error when evaluating a predicate = "The type of Dictionary\(Int16, Utf8\) = Int64 of binary physical should be same" [\#3685](https://github.com/apache/arrow-datafusion/issues/3685) -- Specialized regexp_replace should early-abort when the the input arrays are empty [\#3647](https://github.com/apache/arrow-datafusion/issues/3647) +- Specialized regexp_replace should early-abort when the input arrays are empty [\#3647](https://github.com/apache/arrow-datafusion/issues/3647) - Internal error: Failed to coerce types Decimal128\(10, 2\) and Boolean in BETWEEN expression [\#3646](https://github.com/apache/arrow-datafusion/issues/3646) - Internal error: Failed to coerce types Decimal128\(10, 2\) and Boolean in BETWEEN expression [\#3645](https://github.com/apache/arrow-datafusion/issues/3645) - Type coercion error: The type of Boolean AND Decimal128\(10, 2\) of binary physical should be same [\#3644](https://github.com/apache/arrow-datafusion/issues/3644) diff --git a/dev/changelog/7.0.0.md b/dev/changelog/7.0.0.md index e63c2a4455c9..4d2606d7bfbe 100644 --- a/dev/changelog/7.0.0.md +++ b/dev/changelog/7.0.0.md @@ -56,7 +56,7 @@ - Keep all datafusion's packages up to date with Dependabot [\#1472](https://github.com/apache/arrow-datafusion/issues/1472) - ExecutionContext support init ExecutionContextState with `new(state: Arc>)` method [\#1439](https://github.com/apache/arrow-datafusion/issues/1439) - support the decimal scalar value [\#1393](https://github.com/apache/arrow-datafusion/issues/1393) -- Documentation for using scalar functions with the the DataFrame API [\#1364](https://github.com/apache/arrow-datafusion/issues/1364) +- Documentation for using scalar functions with the DataFrame API [\#1364](https://github.com/apache/arrow-datafusion/issues/1364) - Support `boolean == boolean` and `boolean != boolean` operators [\#1159](https://github.com/apache/arrow-datafusion/issues/1159) - Support DataType::Decimal\(15, 2\) in TPC-H benchmark [\#174](https://github.com/apache/arrow-datafusion/issues/174) - Make `MemoryStream` public [\#150](https://github.com/apache/arrow-datafusion/issues/150) diff --git a/docs/source/contributor-guide/communication.md b/docs/source/contributor-guide/communication.md index 8678aa534baf..7b5e71bc3a1c 100644 --- a/docs/source/contributor-guide/communication.md +++ b/docs/source/contributor-guide/communication.md @@ -44,7 +44,7 @@ request one in the `Arrow Rust` channel of the [Arrow Rust Discord server](https ## Mailing list We also use arrow.apache.org's `dev@` mailing list for release coordination and occasional design discussions. Other -than the the release process, most DataFusion mailing list traffic will link to a GitHub issue or PR for discussion. +than the release process, most DataFusion mailing list traffic will link to a GitHub issue or PR for discussion. ([subscribe](mailto:dev-subscribe@arrow.apache.org), [unsubscribe](mailto:dev-unsubscribe@arrow.apache.org), [archives](https://lists.apache.org/list.html?dev@arrow.apache.org)). diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index f433e026e0a2..ad210724103d 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -204,7 +204,7 @@ let df = ctx.sql(&sql).await.unwrap(); ## Adding a Window UDF -Scalar UDFs are functions that take a row of data and return a single value. Window UDFs are similar, but they also have access to the rows around them. Access to the the proximal rows is helpful, but adds some complexity to the implementation. +Scalar UDFs are functions that take a row of data and return a single value. Window UDFs are similar, but they also have access to the rows around them. Access to the proximal rows is helpful, but adds some complexity to the implementation. For example, we will declare a user defined window function that computes a moving average. From c0a21b28c7dadd7d3e1db1fbe2433735a2b65d5a Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Mon, 18 Mar 2024 14:23:16 -0400 Subject: [PATCH 04/35] Update example-usage.md to remove reference to simd and rust nightly. (#9677) * Fix to_timestamp benchmark * Remove reference to simd and nightly build as simd is no longer an available feature in DataFusion and building with nightly may not be a good recommendation when getting started. --- docs/source/user-guide/example-usage.md | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/docs/source/user-guide/example-usage.md b/docs/source/user-guide/example-usage.md index 1c5c8f49a16a..c5eefbdaf156 100644 --- a/docs/source/user-guide/example-usage.md +++ b/docs/source/user-guide/example-usage.md @@ -240,17 +240,11 @@ async fn main() -> datafusion::error::Result<()> { } ``` -Finally, in order to build with the `simd` optimization `cargo nightly` is required. - -```shell -rustup toolchain install nightly -``` - Based on the instruction set architecture you are building on you will want to configure the `target-cpu` as well, ideally with `native` or at least `avx2`. ```shell -RUSTFLAGS='-C target-cpu=native' cargo +nightly run --release +RUSTFLAGS='-C target-cpu=native' cargo run --release ``` ## Enable backtraces From 35ff7a66c0e2579489e1408bb426fe4444f6ce2e Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Mon, 18 Mar 2024 21:23:31 +0300 Subject: [PATCH 05/35] Minor changes (#9674) --- .../physical-expr/src/window/nth_value.rs | 42 ++++++++++--------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index e913f39333f9..9de71c2d604c 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -225,31 +225,38 @@ impl PartitionEvaluator for NthValueEvaluator { } // Extract valid indices if ignoring nulls. - let (slice, valid_indices) = if self.ignore_nulls { + let valid_indices = if self.ignore_nulls { + // Calculate valid indices, inside the window frame boundaries let slice = arr.slice(range.start, n_range); - let valid_indices = - slice.nulls().unwrap().valid_indices().collect::>(); + let valid_indices = slice + .nulls() + .map(|nulls| { + nulls + .valid_indices() + // Add offset `range.start` to valid indices, to point correct index in the original arr. + .map(|idx| idx + range.start) + .collect::>() + }) + .unwrap_or_default(); if valid_indices.is_empty() { return ScalarValue::try_from(arr.data_type()); } - (Some(slice), Some(valid_indices)) + Some(valid_indices) } else { - (None, None) + None }; match self.state.kind { NthValueKind::First => { - if let Some(slice) = &slice { - let valid_indices = valid_indices.unwrap(); - ScalarValue::try_from_array(slice, valid_indices[0]) + if let Some(valid_indices) = &valid_indices { + ScalarValue::try_from_array(arr, valid_indices[0]) } else { ScalarValue::try_from_array(arr, range.start) } } NthValueKind::Last => { - if let Some(slice) = &slice { - let valid_indices = valid_indices.unwrap(); + if let Some(valid_indices) = &valid_indices { ScalarValue::try_from_array( - slice, + arr, valid_indices[valid_indices.len() - 1], ) } else { @@ -264,15 +271,11 @@ impl PartitionEvaluator for NthValueEvaluator { if index >= n_range { // Outside the range, return NULL: ScalarValue::try_from(arr.data_type()) - } else if self.ignore_nulls { - let valid_indices = valid_indices.unwrap(); + } else if let Some(valid_indices) = valid_indices { if index >= valid_indices.len() { return ScalarValue::try_from(arr.data_type()); } - ScalarValue::try_from_array( - &slice.unwrap(), - valid_indices[index], - ) + ScalarValue::try_from_array(&arr, valid_indices[index]) } else { ScalarValue::try_from_array(arr, range.start + index) } @@ -282,14 +285,13 @@ impl PartitionEvaluator for NthValueEvaluator { if n_range < reverse_index { // Outside the range, return NULL: ScalarValue::try_from(arr.data_type()) - } else if self.ignore_nulls { - let valid_indices = valid_indices.unwrap(); + } else if let Some(valid_indices) = valid_indices { if reverse_index > valid_indices.len() { return ScalarValue::try_from(arr.data_type()); } let new_index = valid_indices[valid_indices.len() - reverse_index]; - ScalarValue::try_from_array(&slice.unwrap(), new_index) + ScalarValue::try_from_array(&arr, new_index) } else { ScalarValue::try_from_array( arr, From 4687a2f793019deb199f1759f5171730a6434189 Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 18 Mar 2024 13:59:00 -0700 Subject: [PATCH 06/35] minor: Remove deprecated methods (#9627) * minor: remove deprecared code * Remove deprecated test * docs --------- Co-authored-by: Andrew Lamb --- datafusion/common/src/dfschema.rs | 51 +++----------- datafusion/core/src/dataframe/mod.rs | 10 --- datafusion/core/src/datasource/listing/url.rs | 33 --------- datafusion/core/src/execution/context/mod.rs | 43 ------------ datafusion/execution/src/config.rs | 16 +---- datafusion/execution/src/task.rs | 45 ++---------- datafusion/expr/src/aggregate_function.rs | 22 ------ datafusion/expr/src/expr.rs | 28 -------- datafusion/expr/src/expr_rewriter/mod.rs | 43 ------------ datafusion/expr/src/function.rs | 25 +------ datafusion/expr/src/logical_plan/plan.rs | 70 ------------------- datafusion/physical-plan/src/common.rs | 6 -- datafusion/physical-plan/src/sorts/sort.rs | 27 ------- 13 files changed, 18 insertions(+), 401 deletions(-) diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 2642032c9a04..597507a044a2 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -97,10 +97,11 @@ pub type DFSchemaRef = Arc; /// ```rust /// use datafusion_common::{DFSchema, DFField}; /// use arrow_schema::Schema; +/// use std::collections::HashMap; /// -/// let df_schema = DFSchema::new(vec![ +/// let df_schema = DFSchema::new_with_metadata(vec![ /// DFField::new_unqualified("c1", arrow::datatypes::DataType::Int32, false), -/// ]).unwrap(); +/// ], HashMap::new()).unwrap(); /// let schema = Schema::from(df_schema); /// assert_eq!(schema.fields().len(), 1); /// ``` @@ -124,12 +125,6 @@ impl DFSchema { } } - #[deprecated(since = "7.0.0", note = "please use `new_with_metadata` instead")] - /// Create a new `DFSchema` - pub fn new(fields: Vec) -> Result { - Self::new_with_metadata(fields, HashMap::new()) - } - /// Create a new `DFSchema` pub fn new_with_metadata( fields: Vec, @@ -251,32 +246,6 @@ impl DFSchema { &self.fields[i] } - #[deprecated(since = "8.0.0", note = "please use `index_of_column_by_name` instead")] - /// Find the index of the column with the given unqualified name - pub fn index_of(&self, name: &str) -> Result { - for i in 0..self.fields.len() { - if self.fields[i].name() == name { - return Ok(i); - } else { - // Now that `index_of` is deprecated an error is thrown if - // a fully qualified field name is provided. - match &self.fields[i].qualifier { - Some(qualifier) => { - if (qualifier.to_string() + "." + self.fields[i].name()) == name { - return _plan_err!( - "Fully qualified field name '{name}' was supplied to `index_of` \ - which is deprecated. Please use `index_of_column_by_name` instead" - ); - } - } - None => (), - } - } - } - - Err(unqualified_field_not_found(name, self)) - } - pub fn index_of_column_by_name( &self, qualifier: Option<&TableReference>, @@ -1146,13 +1115,10 @@ mod tests { Ok(()) } - #[allow(deprecated)] #[test] fn helpful_error_messages() -> Result<()> { let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; let expected_help = "Valid fields are t1.c0, t1.c1."; - // Pertinent message parts - let expected_err_msg = "Fully qualified field name 't1.c0'"; assert_contains!( schema .field_with_qualified_name(&TableReference::bare("x"), "y") @@ -1167,11 +1133,12 @@ mod tests { .to_string(), expected_help ); - assert_contains!(schema.index_of("y").unwrap_err().to_string(), expected_help); - assert_contains!( - schema.index_of("t1.c0").unwrap_err().to_string(), - expected_err_msg - ); + assert!(schema.index_of_column_by_name(None, "y").unwrap().is_none()); + assert!(schema + .index_of_column_by_name(None, "t1.c0") + .unwrap() + .is_none()); + Ok(()) } diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 5f192b83fdd9..25830401571d 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1001,16 +1001,6 @@ impl DataFrame { Arc::new(DataFrameTableProvider { plan: self.plan }) } - /// Return the optimized logical plan represented by this DataFrame. - /// - /// Note: This method should not be used outside testing, as it loses the snapshot - /// of the [`SessionState`] attached to this [`DataFrame`] and consequently subsequent - /// operations may take place against a different state - #[deprecated(since = "23.0.0", note = "Use DataFrame::into_optimized_plan")] - pub fn to_logical_plan(self) -> Result { - self.into_optimized_plan() - } - /// Return a DataFrame with the explanation of its plan so far. /// /// if `analyze` is specified, runs the plan and reports metrics diff --git a/datafusion/core/src/datasource/listing/url.rs b/datafusion/core/src/datasource/listing/url.rs index d9149bcc20e0..eb95dc7b1d24 100644 --- a/datafusion/core/src/datasource/listing/url.rs +++ b/datafusion/core/src/datasource/listing/url.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::fs; - use crate::datasource::object_store::ObjectStoreUrl; use crate::execution::context::SessionState; use datafusion_common::{DataFusionError, Result}; @@ -117,37 +115,6 @@ impl ListingTableUrl { } } - /// Get object store for specified input_url - /// if input_url is actually not a url, we assume it is a local file path - /// if we have a local path, create it if not exists so ListingTableUrl::parse works - #[deprecated(note = "Use parse")] - pub fn parse_create_local_if_not_exists( - s: impl AsRef, - is_directory: bool, - ) -> Result { - let s = s.as_ref(); - let is_valid_url = Url::parse(s).is_ok(); - - match is_valid_url { - true => ListingTableUrl::parse(s), - false => { - let path = std::path::PathBuf::from(s); - if !path.exists() { - if is_directory { - fs::create_dir_all(path)?; - } else { - // ensure parent directory exists - if let Some(parent) = path.parent() { - fs::create_dir_all(parent)?; - } - fs::File::create(path)?; - } - } - ListingTableUrl::parse(s) - } - } - } - /// Creates a new [`ListingTableUrl`] interpreting `s` as a filesystem path #[cfg(not(target_arch = "wasm32"))] fn parse_path(s: &str) -> Result { diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 32c1c60ec564..1ac7da465216 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1181,49 +1181,6 @@ impl SessionContext { } } - /// Returns the set of available tables in the default catalog and - /// schema. - /// - /// Use [`table`] to get a specific table. - /// - /// [`table`]: SessionContext::table - #[deprecated( - since = "23.0.0", - note = "Please use the catalog provider interface (`SessionContext::catalog`) to examine available catalogs, schemas, and tables" - )] - pub fn tables(&self) -> Result> { - Ok(self - .state - .read() - // a bare reference will always resolve to the default catalog and schema - .schema_for_ref(TableReference::Bare { table: "".into() })? - .table_names() - .iter() - .cloned() - .collect()) - } - - /// Optimizes the logical plan by applying optimizer rules. - #[deprecated( - since = "23.0.0", - note = "Use SessionState::optimize to ensure a consistent state for planning and execution" - )] - pub fn optimize(&self, plan: &LogicalPlan) -> Result { - self.state.read().optimize(plan) - } - - /// Creates a physical plan from a logical plan. - #[deprecated( - since = "23.0.0", - note = "Use SessionState::create_physical_plan or DataFrame::create_physical_plan to ensure a consistent state for planning and execution" - )] - pub async fn create_physical_plan( - &self, - logical_plan: &LogicalPlan, - ) -> Result> { - self.state().create_physical_plan(logical_plan).await - } - /// Get a new TaskContext to run in this session pub fn task_ctx(&self) -> Arc { Arc::new(TaskContext::from(self)) diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index 312aef953e9c..360bac71c510 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -434,9 +434,9 @@ impl SessionConfig { /// converted to strings. /// /// Note that this method will eventually be deprecated and - /// replaced by [`config_options`]. + /// replaced by [`options`]. /// - /// [`config_options`]: Self::config_options + /// [`options`]: Self::options pub fn to_props(&self) -> HashMap { let mut map = HashMap::new(); // copy configs from config_options @@ -447,18 +447,6 @@ impl SessionConfig { map } - /// Return a handle to the configuration options. - #[deprecated(since = "21.0.0", note = "use options() instead")] - pub fn config_options(&self) -> &ConfigOptions { - &self.options - } - - /// Return a mutable handle to the configuration options. - #[deprecated(since = "21.0.0", note = "use options_mut() instead")] - pub fn config_options_mut(&mut self) -> &mut ConfigOptions { - &mut self.options - } - /// Add extensions. /// /// Extensions can be used to attach extra data to the session config -- e.g. tracing information or caches. diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index cae410655d10..4216ce95f35e 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -20,10 +20,7 @@ use std::{ sync::Arc, }; -use datafusion_common::{ - config::{ConfigOptions, Extensions}, - plan_datafusion_err, DataFusionError, Result, -}; +use datafusion_common::{plan_datafusion_err, DataFusionError, Result}; use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use crate::{ @@ -102,39 +99,6 @@ impl TaskContext { } } - /// Create a new task context instance, by first copying all - /// name/value pairs from `task_props` into a `SessionConfig`. - #[deprecated( - since = "21.0.0", - note = "Construct SessionConfig and call TaskContext::new() instead" - )] - pub fn try_new( - task_id: String, - session_id: String, - task_props: HashMap, - scalar_functions: HashMap>, - aggregate_functions: HashMap>, - runtime: Arc, - extensions: Extensions, - ) -> Result { - let mut config = ConfigOptions::new().with_extensions(extensions); - for (k, v) in task_props { - config.set(&k, &v)?; - } - let session_config = SessionConfig::from(config); - let window_functions = HashMap::new(); - - Ok(Self::new( - Some(task_id), - session_id, - session_config, - scalar_functions, - aggregate_functions, - window_functions, - runtime, - )) - } - /// Return the SessionConfig associated with this [TaskContext] pub fn session_config(&self) -> &SessionConfig { &self.session_config @@ -160,7 +124,7 @@ impl TaskContext { self.runtime.clone() } - /// Update the [`ConfigOptions`] + /// Update the [`SessionConfig`] pub fn with_session_config(mut self, session_config: SessionConfig) -> Self { self.session_config = session_config; self @@ -229,7 +193,10 @@ impl FunctionRegistry for TaskContext { #[cfg(test)] mod tests { use super::*; - use datafusion_common::{config::ConfigExtension, extensions_options}; + use datafusion_common::{ + config::{ConfigExtension, ConfigOptions, Extensions}, + extensions_options, + }; extensions_options! { struct TestExtension { diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 574de3e7082a..85f8c74f3737 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -218,19 +218,6 @@ impl FromStr for AggregateFunction { } } -/// Returns the datatype of the aggregate function. -/// This is used to get the returned data type for aggregate expr. -#[deprecated( - since = "27.0.0", - note = "please use `AggregateFunction::return_type` instead" -)] -pub fn return_type( - fun: &AggregateFunction, - input_expr_types: &[DataType], -) -> Result { - fun.return_type(input_expr_types) -} - impl AggregateFunction { /// Returns the datatype of the aggregate function given its argument types /// @@ -328,15 +315,6 @@ pub fn sum_type_of_avg(input_expr_types: &[DataType]) -> Result { avg_sum_type(&coerced_data_types[0]) } -/// the signatures supported by the function `fun`. -#[deprecated( - since = "27.0.0", - note = "please use `AggregateFunction::signature` instead" -)] -pub fn signature(fun: &AggregateFunction) -> Signature { - fun.signature() -} - impl AggregateFunction { /// the signatures supported by the function `fun`. pub fn signature(&self) -> Signature { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 0da05d96f67e..7ede4cd8ffc9 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -703,27 +703,6 @@ pub fn find_df_window_func(name: &str) -> Option { } } -/// Returns the datatype of the window function -#[deprecated( - since = "27.0.0", - note = "please use `WindowFunction::return_type` instead" -)] -pub fn return_type( - fun: &WindowFunctionDefinition, - input_expr_types: &[DataType], -) -> Result { - fun.return_type(input_expr_types) -} - -/// the signatures supported by the function `fun`. -#[deprecated( - since = "27.0.0", - note = "please use `WindowFunction::signature` instead" -)] -pub fn signature(fun: &WindowFunctionDefinition) -> Signature { - fun.signature() -} - // Exists expression. #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Exists { @@ -887,13 +866,6 @@ impl Expr { create_name(self) } - /// Returns the name of this expression as it should appear in a schema. This name - /// will not include any CAST expressions. - #[deprecated(since = "14.0.0", note = "please use `display_name` instead")] - pub fn name(&self) -> Result { - self.display_name() - } - /// Returns a full and complete string representation of this expression. pub fn canonical_name(&self) -> String { format!("{self}") diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 357b1aed7dde..ea3ffadda391 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -74,31 +74,6 @@ pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { .data() } -/// Recursively call [`Column::normalize_with_schemas`] on all [`Column`] expressions -/// in the `expr` expression tree. -#[deprecated( - since = "20.0.0", - note = "use normalize_col_with_schemas_and_ambiguity_check instead" -)] -#[allow(deprecated)] -pub fn normalize_col_with_schemas( - expr: Expr, - schemas: &[&Arc], - using_columns: &[HashSet], -) -> Result { - expr.transform(&|expr| { - Ok({ - if let Expr::Column(c) = expr { - let col = c.normalize_with_schemas(schemas, using_columns)?; - Transformed::yes(Expr::Column(col)) - } else { - Transformed::no(expr) - } - }) - }) - .data() -} - /// See [`Column::normalize_with_schemas_and_ambiguity_check`] for usage pub fn normalize_col_with_schemas_and_ambiguity_check( expr: Expr, @@ -398,24 +373,6 @@ mod test { ); } - #[test] - #[allow(deprecated)] - fn normalize_cols_priority() { - let expr = col("a") + col("b"); - // Schemas with multiple matches for column a, first takes priority - let schema_a = make_schema_with_empty_metadata(vec![make_field("tableA", "a")]); - let schema_b = make_schema_with_empty_metadata(vec![make_field("tableB", "b")]); - let schema_a2 = make_schema_with_empty_metadata(vec![make_field("tableA2", "a")]); - let schemas = vec![schema_a2, schema_b, schema_a] - .into_iter() - .map(Arc::new) - .collect::>(); - let schemas = schemas.iter().collect::>(); - - let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap(); - assert_eq!(normalized_expr, col("tableA2.a") + col("tableB.b")); - } - #[test] fn normalize_cols_non_exist() { // test normalizing columns when the name doesn't exist diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index a3760eeb357d..adf4dd3fef20 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -17,9 +17,7 @@ //! Function module contains typing and signature for built-in and user defined functions. -use crate::{ - Accumulator, BuiltinScalarFunction, ColumnarValue, PartitionEvaluator, Signature, -}; +use crate::{Accumulator, ColumnarValue, PartitionEvaluator}; use arrow::datatypes::DataType; use datafusion_common::Result; use std::sync::Arc; @@ -53,24 +51,3 @@ pub type PartitionEvaluatorFactory = /// its state, given its return datatype. pub type StateTypeFunction = Arc Result>> + Send + Sync>; - -/// Returns the datatype of the scalar function -#[deprecated( - since = "27.0.0", - note = "please use `BuiltinScalarFunction::return_type` instead" -)] -pub fn return_type( - fun: &BuiltinScalarFunction, - input_expr_types: &[DataType], -) -> Result { - fun.return_type(input_expr_types) -} - -/// Return the [`Signature`] supported by the function `fun`. -#[deprecated( - since = "27.0.0", - note = "please use `BuiltinScalarFunction::signature` instead" -)] -pub fn signature(fun: &BuiltinScalarFunction) -> Signature { - fun.signature() -} diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index c6f280acb255..08fe3380061f 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -217,56 +217,6 @@ impl LogicalPlan { } } - /// Get all meaningful schemas of a plan and its children plan. - #[deprecated(since = "20.0.0")] - pub fn all_schemas(&self) -> Vec<&DFSchemaRef> { - match self { - // return self and children schemas - LogicalPlan::Window(_) - | LogicalPlan::Projection(_) - | LogicalPlan::Aggregate(_) - | LogicalPlan::Unnest(_) - | LogicalPlan::Join(_) - | LogicalPlan::CrossJoin(_) => { - let mut schemas = vec![self.schema()]; - self.inputs().iter().for_each(|input| { - schemas.push(input.schema()); - }); - schemas - } - // just return self.schema() - LogicalPlan::Explain(_) - | LogicalPlan::Analyze(_) - | LogicalPlan::EmptyRelation(_) - | LogicalPlan::Ddl(_) - | LogicalPlan::Dml(_) - | LogicalPlan::Copy(_) - | LogicalPlan::Values(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Union(_) - | LogicalPlan::Extension(_) - | LogicalPlan::TableScan(_) => { - vec![self.schema()] - } - LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { - // return only the schema of the static term - static_term.all_schemas() - } - // return children schemas - LogicalPlan::Limit(_) - | LogicalPlan::Subquery(_) - | LogicalPlan::Repartition(_) - | LogicalPlan::Sort(_) - | LogicalPlan::Filter(_) - | LogicalPlan::Distinct(_) - | LogicalPlan::Prepare(_) => { - self.inputs().iter().map(|p| p.schema()).collect() - } - // return empty - LogicalPlan::Statement(_) | LogicalPlan::DescribeTable(_) => vec![], - } - } - /// Returns the (fixed) output schema for explain plans pub fn explain_schema() -> SchemaRef { SchemaRef::new(Schema::new(vec![ @@ -3079,14 +3029,6 @@ digraph { empty_schema: DFSchemaRef, } - impl NoChildExtension { - fn empty() -> Self { - Self { - empty_schema: Arc::new(DFSchema::empty()), - } - } - } - impl UserDefinedLogicalNode for NoChildExtension { fn as_any(&self) -> &dyn std::any::Any { unimplemented!() @@ -3129,18 +3071,6 @@ digraph { } } - #[test] - #[allow(deprecated)] - fn test_extension_all_schemas() { - let plan = LogicalPlan::Extension(Extension { - node: Arc::new(NoChildExtension::empty()), - }); - - let schemas = plan.all_schemas(); - assert_eq!(1, schemas.len()); - assert_eq!(0, schemas[0].fields().len()); - } - #[test] fn test_replace_invalid_placeholder() { // test empty placeholder diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index f4a2cba68e16..59c54199333e 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -349,12 +349,6 @@ pub fn can_project( } } -/// Returns the total number of bytes of memory occupied physically by this batch. -#[deprecated(since = "28.0.0", note = "RecordBatch::get_array_memory_size")] -pub fn batch_byte_size(batch: &RecordBatch) -> usize { - batch.get_array_memory_size() -} - #[cfg(test)] mod tests { use std::ops::Not; diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index db352bb2c86f..a80dab058ca6 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -733,16 +733,6 @@ pub struct SortExec { } impl SortExec { - /// Create a new sort execution plan - #[deprecated(since = "22.0.0", note = "use `new` and `with_fetch`")] - pub fn try_new( - expr: Vec, - input: Arc, - fetch: Option, - ) -> Result { - Ok(Self::new(expr, input).with_fetch(fetch)) - } - /// Create a new sort execution plan that produces a single, /// sorted output partition. pub fn new(expr: Vec, input: Arc) -> Self { @@ -758,23 +748,6 @@ impl SortExec { } } - /// Create a new sort execution plan with the option to preserve - /// the partitioning of the input plan - #[deprecated( - since = "22.0.0", - note = "use `new`, `with_fetch` and `with_preserve_partioning` instead" - )] - pub fn new_with_partitioning( - expr: Vec, - input: Arc, - preserve_partitioning: bool, - fetch: Option, - ) -> Self { - Self::new(expr, input) - .with_fetch(fetch) - .with_preserve_partitioning(preserve_partitioning) - } - /// Whether this `SortExec` preserves partitioning of the children pub fn preserve_partitioning(&self) -> bool { self.preserve_partitioning From 2499245f348f2b8fe9777ab7ff7552642c56b4ce Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 18 Mar 2024 17:12:04 -0400 Subject: [PATCH 07/35] Migrate `arrow_cast` to a UDF (#9610) * feat: arrow_cast function as UDF * fix: cargo.lock in datafusion-cli * fix: unwrap arg1 on match arm Co-authored-by: Andrew Lamb * fix: unwrap on matching arms using some * Rewrite to use simplify API * Update error messages * Fix up tests * Update cargo.lock * fix test * fix * Fix merge errors, return error --------- Co-authored-by: Brayan Jules Co-authored-by: Brayan Jules --- datafusion-examples/examples/to_char.rs | 46 +++--- .../core/tests/optimizer_integration.rs | 27 +++- .../user_defined/user_defined_aggregates.rs | 10 +- .../expr => functions/src/core}/arrow_cast.rs | 145 ++++++++++++------ datafusion/functions/src/core/mod.rs | 3 + datafusion/sql/src/expr/function.rs | 8 - datafusion/sql/src/expr/mod.rs | 1 - datafusion/sql/src/lib.rs | 1 - datafusion/sql/tests/sql_integration.rs | 14 +- .../sqllogictest/test_files/arrow_typeof.slt | 9 +- 10 files changed, 155 insertions(+), 109 deletions(-) rename datafusion/{sql/src/expr => functions/src/core}/arrow_cast.rs (90%) diff --git a/datafusion-examples/examples/to_char.rs b/datafusion-examples/examples/to_char.rs index e99f69fbcd55..ef616d72cc1c 100644 --- a/datafusion-examples/examples/to_char.rs +++ b/datafusion-examples/examples/to_char.rs @@ -125,14 +125,14 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+------------+", - "| t.values |", - "+------------+", - "| 2020-09-01 |", - "| 2020-09-02 |", - "| 2020-09-03 |", - "| 2020-09-04 |", - "+------------+", + "+-----------------------------------+", + "| arrow_cast(t.values,Utf8(\"Utf8\")) |", + "+-----------------------------------+", + "| 2020-09-01 |", + "| 2020-09-02 |", + "| 2020-09-03 |", + "| 2020-09-04 |", + "+-----------------------------------+", ], &result ); @@ -146,11 +146,11 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+-----------------------------------------------------------------+", - "| to_char(Utf8(\"2023-08-03 14:38:50Z\"),Utf8(\"%d-%m-%Y %H:%M:%S\")) |", - "+-----------------------------------------------------------------+", - "| 03-08-2023 14:38:50 |", - "+-----------------------------------------------------------------+", + "+-------------------------------------------------------------------------------------------------------------+", + "| to_char(arrow_cast(Utf8(\"2023-08-03 14:38:50Z\"),Utf8(\"Timestamp(Second, None)\")),Utf8(\"%d-%m-%Y %H:%M:%S\")) |", + "+-------------------------------------------------------------------------------------------------------------+", + "| 03-08-2023 14:38:50 |", + "+-------------------------------------------------------------------------------------------------------------+", ], &result ); @@ -165,11 +165,11 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+---------------------------------------+", - "| to_char(Int64(123456),Utf8(\"pretty\")) |", - "+---------------------------------------+", - "| 1 days 10 hours 17 mins 36 secs |", - "+---------------------------------------+", + "+----------------------------------------------------------------------------+", + "| to_char(arrow_cast(Int64(123456),Utf8(\"Duration(Second)\")),Utf8(\"pretty\")) |", + "+----------------------------------------------------------------------------+", + "| 1 days 10 hours 17 mins 36 secs |", + "+----------------------------------------------------------------------------+", ], &result ); @@ -184,11 +184,11 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+----------------------------------------+", - "| to_char(Int64(123456),Utf8(\"iso8601\")) |", - "+----------------------------------------+", - "| PT123456S |", - "+----------------------------------------+", + "+-----------------------------------------------------------------------------+", + "| to_char(arrow_cast(Int64(123456),Utf8(\"Duration(Second)\")),Utf8(\"iso8601\")) |", + "+-----------------------------------------------------------------------------+", + "| PT123456S |", + "+-----------------------------------------------------------------------------+", ], &result ); diff --git a/datafusion/core/tests/optimizer_integration.rs b/datafusion/core/tests/optimizer_integration.rs index f9696955769e..60010bdddfb8 100644 --- a/datafusion/core/tests/optimizer_integration.rs +++ b/datafusion/core/tests/optimizer_integration.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +//! Tests for the DataFusion SQL query planner that require functions from the +//! datafusion-functions crate. + use std::any::Any; use std::collections::HashMap; use std::sync::Arc; @@ -42,12 +45,18 @@ fn init() { let _ = env_logger::try_init(); } +#[test] +fn select_arrow_cast() { + let sql = "SELECT arrow_cast(1234, 'Float64') as f64, arrow_cast('foo', 'LargeUtf8') as large"; + let expected = "Projection: Float64(1234) AS f64, LargeUtf8(\"foo\") AS large\ + \n EmptyRelation"; + quick_test(sql, expected); +} #[test] fn timestamp_nano_ts_none_predicates() -> Result<()> { let sql = "SELECT col_int32 FROM test WHERE col_ts_nano_none < (now() - interval '1 hour')"; - let plan = test_sql(sql)?; // a scan should have the now()... predicate folded to a single // constant and compared to the column without a cast so it can be // pushed down / pruned @@ -55,7 +64,7 @@ fn timestamp_nano_ts_none_predicates() -> Result<()> { "Projection: test.col_int32\ \n Filter: test.col_ts_nano_none < TimestampNanosecond(1666612093000000000, None)\ \n TableScan: test projection=[col_int32, col_ts_nano_none]"; - assert_eq!(expected, format!("{plan:?}")); + quick_test(sql, expected); Ok(()) } @@ -74,6 +83,11 @@ fn timestamp_nano_ts_utc_predicates() { assert_eq!(expected, format!("{plan:?}")); } +fn quick_test(sql: &str, expected_plan: &str) { + let plan = test_sql(sql).unwrap(); + assert_eq!(expected_plan, format!("{:?}", plan)); +} + fn test_sql(sql: &str) -> Result { // parse the SQL let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ... @@ -81,12 +95,9 @@ fn test_sql(sql: &str) -> Result { let statement = &ast[0]; // create a logical query plan - let now_udf = datetime::functions() - .iter() - .find(|f| f.name() == "now") - .unwrap() - .to_owned(); - let context_provider = MyContextProvider::default().with_udf(now_udf); + let context_provider = MyContextProvider::default() + .with_udf(datetime::now()) + .with_udf(datafusion_functions::core::arrow_cast()); let sql_to_rel = SqlToRel::new(&context_provider); let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 3f40c55a3ed7..a58a8cf51681 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -184,11 +184,11 @@ async fn test_udaf_shadows_builtin_fn() { // compute with builtin `sum` aggregator let expected = [ - "+-------------+", - "| SUM(t.time) |", - "+-------------+", - "| 19000 |", - "+-------------+", + "+---------------------------------------+", + "| SUM(arrow_cast(t.time,Utf8(\"Int64\"))) |", + "+---------------------------------------+", + "| 19000 |", + "+---------------------------------------+", ]; assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap()); diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs similarity index 90% rename from datafusion/sql/src/expr/arrow_cast.rs rename to datafusion/functions/src/core/arrow_cast.rs index a75cdf9e3c6b..b6c1b5eb9a38 100644 --- a/datafusion/sql/src/expr/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -15,63 +15,125 @@ // specific language governing permissions and limitations // under the License. -//! Implementation of the `arrow_cast` function that allows -//! casting to arbitrary arrow types (rather than SQL types) +//! [`ArrowCastFunc`]: Implementation of the `arrow_cast` +use std::any::Any; use std::{fmt::Display, iter::Peekable, str::Chars, sync::Arc}; -use arrow_schema::{DataType, Field, IntervalUnit, TimeUnit}; +use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit}; use datafusion_common::{ - plan_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue, + internal_err, plan_datafusion_err, plan_err, DataFusionError, ExprSchema, Result, + ScalarValue, }; -use datafusion_common::plan_err; -use datafusion_expr::{Expr, ExprSchemable}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; -pub const ARROW_CAST_NAME: &str = "arrow_cast"; - -/// Create an [`Expr`] that evaluates the `arrow_cast` function +/// Implements casting to arbitrary arrow types (rather than SQL types) +/// +/// Note that the `arrow_cast` function is somewhat special in that its +/// return depends only on the *value* of its second argument (not its type) /// -/// This function is not a [`BuiltinScalarFunction`] because the -/// return type of [`BuiltinScalarFunction`] depends only on the -/// *types* of the arguments. However, the type of `arrow_type` depends on -/// the *value* of its second argument. +/// It is implemented by calling the same underlying arrow `cast` kernel as +/// normal SQL casts. /// -/// Use the `cast` function to cast to SQL type (which is then mapped -/// to the corresponding arrow type). For example to cast to `int` -/// (which is then mapped to the arrow type `Int32`) +/// For example to cast to `int` using SQL (which is then mapped to the arrow +/// type `Int32`) /// /// ```sql /// select cast(column_x as int) ... /// ``` /// -/// Use the `arrow_cast` functiont to cast to a specfic arrow type +/// You can use the `arrow_cast` functiont to cast to a specific arrow type /// /// For example /// ```sql /// select arrow_cast(column_x, 'Float64') /// ``` -/// [`BuiltinScalarFunction`]: datafusion_expr::BuiltinScalarFunction -pub fn create_arrow_cast(mut args: Vec, schema: &DFSchema) -> Result { +#[derive(Debug)] +pub(super) struct ArrowCastFunc { + signature: Signature, +} + +impl ArrowCastFunc { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for ArrowCastFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "arrow_cast" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + // should be using return_type_from_exprs and not calling the default + // implementation + internal_err!("arrow_cast should return type from exprs") + } + + fn return_type_from_exprs( + &self, + args: &[Expr], + _schema: &dyn ExprSchema, + _arg_types: &[DataType], + ) -> Result { + data_type_from_args(args) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + internal_err!("arrow_cast should have been simplified to cast") + } + + fn simplify( + &self, + mut args: Vec, + info: &dyn SimplifyInfo, + ) -> Result { + // convert this into a real cast + let target_type = data_type_from_args(&args)?; + // remove second (type) argument + args.pop().unwrap(); + let arg = args.pop().unwrap(); + + let source_type = info.get_data_type(&arg)?; + let new_expr = if source_type == target_type { + // the argument's data type is already the correct type + arg + } else { + // Use an actual cast to get the correct type + Expr::Cast(datafusion_expr::Cast { + expr: Box::new(arg), + data_type: target_type, + }) + }; + // return the newly written argument to DataFusion + Ok(ExprSimplifyResult::Simplified(new_expr)) + } +} + +/// Returns the requested type from the arguments +fn data_type_from_args(args: &[Expr]) -> Result { if args.len() != 2 { return plan_err!("arrow_cast needs 2 arguments, {} provided", args.len()); } - let arg1 = args.pop().unwrap(); - let arg0 = args.pop().unwrap(); - - // arg1 must be a string - let data_type_string = if let Expr::Literal(ScalarValue::Utf8(Some(v))) = arg1 { - v - } else { + let Expr::Literal(ScalarValue::Utf8(Some(val))) = &args[1] else { return plan_err!( - "arrow_cast requires its second argument to be a constant string, got {arg1}" + "arrow_cast requires its second argument to be a constant string, got {:?}", + &args[1] ); }; - - // do the actual lookup to the appropriate data type - let data_type = parse_data_type(&data_type_string)?; - - arg0.cast_to(&data_type, schema) + parse_data_type(val) } /// Parses `str` into a `DataType`. @@ -80,22 +142,8 @@ pub fn create_arrow_cast(mut args: Vec, schema: &DFSchema) -> Result /// impl, and maintains the invariant that /// `parse_data_type(data_type.to_string()) == data_type` /// -/// Example: -/// ``` -/// # use datafusion_sql::parse_data_type; -/// # use arrow_schema::DataType; -/// let display_value = "Int32"; -/// -/// // "Int32" is the Display value of `DataType` -/// assert_eq!(display_value, &format!("{}", DataType::Int32)); -/// -/// // parse_data_type coverts "Int32" back to `DataType`: -/// let data_type = parse_data_type(display_value).unwrap(); -/// assert_eq!(data_type, DataType::Int32); -/// ``` -/// /// Remove if added to arrow: -pub fn parse_data_type(val: &str) -> Result { +fn parse_data_type(val: &str) -> Result { Parser::new(val).parse() } @@ -647,8 +695,6 @@ impl Display for Token { #[cfg(test)] mod test { - use arrow_schema::{IntervalUnit, TimeUnit}; - use super::*; #[test] @@ -844,7 +890,6 @@ mod test { assert!(message.contains("Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanosecond, None)'")); } } - println!(" Ok"); } } } diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 73cc4d18bf9f..5a0bd2c77f63 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -17,6 +17,7 @@ //! "core" DataFusion functions +mod arrow_cast; mod arrowtypeof; mod getfield; mod nullif; @@ -25,6 +26,7 @@ mod nvl2; mod r#struct; // create UDFs +make_udf_function!(arrow_cast::ArrowCastFunc, ARROW_CAST, arrow_cast); make_udf_function!(nullif::NullIfFunc, NULLIF, nullif); make_udf_function!(nvl::NVLFunc, NVL, nvl); make_udf_function!(nvl2::NVL2Func, NVL2, nvl2); @@ -35,6 +37,7 @@ make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field); // Export the functions out of this package, both as expr_fn as well as a list of functions export_functions!( (nullif, arg_1 arg_2, "returns NULL if value1 equals value2; otherwise it returns value1. This can be used to perform the inverse operation of the COALESCE expression."), + (arrow_cast, arg_1 arg_2, "returns arg_1 cast to the `arrow_type` given the second argument. This can be used to cast to a specific `arrow_type`."), (nvl, arg_1 arg_2, "returns value2 if value1 is NULL; otherwise it returns value1"), (nvl2, arg_1 arg_2 arg_3, "Returns value2 if value1 is not NULL; otherwise, it returns value3."), (arrow_typeof, arg_1, "Returns the Arrow type of the input expression."), diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index ffc951a6fa66..582404b29749 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -34,8 +34,6 @@ use sqlparser::ast::{ use std::str::FromStr; use strum::IntoEnumIterator; -use super::arrow_cast::ARROW_CAST_NAME; - /// Suggest a valid function based on an invalid input function name pub fn suggest_valid_function( input_function_name: &str, @@ -249,12 +247,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { null_treatment, ))); }; - - // Special case arrow_cast (as its type is dependent on its argument value) - if name == ARROW_CAST_NAME { - let args = self.function_args_to_expr(args, schema, planner_context)?; - return super::arrow_cast::create_arrow_cast(args, schema); - } } // Could not find the relevant function, so return an error diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index a6f1c78c7250..5e9c0623a265 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -pub(crate) mod arrow_cast; mod binary_op; mod function; mod grouping_set; diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index e8e07eebe22d..12d6a4669634 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -42,5 +42,4 @@ pub mod utils; mod values; pub use datafusion_common::{ResolvedTableReference, TableReference}; -pub use expr::arrow_cast::parse_data_type; pub use sqlparser; diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index c9c2bdd694b5..b6077353e5dd 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -2566,15 +2566,6 @@ fn approx_median_window() { quick_test(sql, expected); } -#[test] -fn select_arrow_cast() { - let sql = "SELECT arrow_cast(1234, 'Float64'), arrow_cast('foo', 'LargeUtf8')"; - let expected = "\ - Projection: CAST(Int64(1234) AS Float64), CAST(Utf8(\"foo\") AS LargeUtf8)\ - \n EmptyRelation"; - quick_test(sql, expected); -} - #[test] fn select_typed_date_string() { let sql = "SELECT date '2020-12-10' AS date"; @@ -2670,6 +2661,11 @@ fn logical_plan_with_dialect_and_options( vec![DataType::Int32, DataType::Int32], DataType::Int32, )) + .with_udf(make_udf( + "arrow_cast", + vec![DataType::Int64, DataType::Utf8], + DataType::Float64, + )) .with_udf(make_udf( "date_trunc", vec![DataType::Utf8, DataType::Timestamp(Nanosecond, None)], diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index 8b3bd7eac95d..3e8694f3b2c2 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -92,10 +92,11 @@ SELECT arrow_cast('1', 'Int16') 1 # Basic error test -query error Error during planning: arrow_cast needs 2 arguments, 1 provided +query error DataFusion error: Error during planning: No function matches the given name and argument types 'arrow_cast\(Utf8\)'. You might need to add explicit type casts. SELECT arrow_cast('1') -query error Error during planning: arrow_cast requires its second argument to be a constant string, got Int64\(43\) + +query error DataFusion error: Error during planning: arrow_cast requires its second argument to be a constant string, got Literal\(Int64\(43\)\) SELECT arrow_cast('1', 43) query error Error unrecognized word: unknown @@ -315,7 +316,7 @@ select arrow_cast(interval '30 minutes', 'Duration(Second)'); ---- 0 days 0 hours 30 mins 0 secs -query error DataFusion error: Error during planning: Cannot automatically convert Utf8 to Duration\(Second\) +query error DataFusion error: This feature is not implemented: Unsupported CAST from Utf8 to Duration\(Second\) select arrow_cast('30 minutes', 'Duration(Second)'); @@ -336,7 +337,7 @@ select arrow_cast(timestamp '2000-01-01T00:00:00Z', 'Timestamp(Nanosecond, Some( ---- 2000-01-01T00:00:00+08:00 -statement error Arrow error: Parser error: Invalid timezone "\+25:00": '\+25:00' is not a valid timezone +statement error DataFusion error: Arrow error: Parser error: Invalid timezone "\+25:00": '\+25:00' is not a valid timezone select arrow_cast(timestamp '2000-01-01T00:00:00', 'Timestamp(Nanosecond, Some( "+25:00" ))'); From e53eb036f5c61f7d7bd90047976511628ddca2d0 Mon Sep 17 00:00:00 2001 From: Val Lorentz Date: Mon, 18 Mar 2024 22:13:14 +0100 Subject: [PATCH 08/35] parquet: Add row_groups_matched_{statistics,bloom_filter} statistics (#9640) * test_row_group_prune: Display which assertion failed * Add row_groups_matched_{statistics,bloom_filter} statistics This helps diagnostic whether a Bloom filter mismatches (because of high false-positive probability caused by suboptimal tuning) or is not used at all. --- .../physical_plan/parquet/metrics.rs | 14 ++ .../physical_plan/parquet/row_groups.rs | 4 + datafusion/core/tests/parquet/mod.rs | 17 +++ .../core/tests/parquet/row_group_pruning.rs | 126 +++++++++++++++++- datafusion/core/tests/sql/explain_analyze.rs | 4 + 5 files changed, 160 insertions(+), 5 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/metrics.rs b/datafusion/core/src/datasource/physical_plan/parquet/metrics.rs index a17a3c6d9752..c2a7e4345a5b 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/metrics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/metrics.rs @@ -29,8 +29,12 @@ use crate::physical_plan::metrics::{ pub struct ParquetFileMetrics { /// Number of times the predicate could not be evaluated pub predicate_evaluation_errors: Count, + /// Number of row groups whose bloom filters were checked and matched + pub row_groups_matched_bloom_filter: Count, /// Number of row groups pruned by bloom filters pub row_groups_pruned_bloom_filter: Count, + /// Number of row groups whose statistics were checked and matched + pub row_groups_matched_statistics: Count, /// Number of row groups pruned by statistics pub row_groups_pruned_statistics: Count, /// Total number of bytes scanned @@ -56,10 +60,18 @@ impl ParquetFileMetrics { .with_new_label("filename", filename.to_string()) .counter("predicate_evaluation_errors", partition); + let row_groups_matched_bloom_filter = MetricBuilder::new(metrics) + .with_new_label("filename", filename.to_string()) + .counter("row_groups_matched_bloom_filter", partition); + let row_groups_pruned_bloom_filter = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) .counter("row_groups_pruned_bloom_filter", partition); + let row_groups_matched_statistics = MetricBuilder::new(metrics) + .with_new_label("filename", filename.to_string()) + .counter("row_groups_matched_statistics", partition); + let row_groups_pruned_statistics = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) .counter("row_groups_pruned_statistics", partition); @@ -85,7 +97,9 @@ impl ParquetFileMetrics { Self { predicate_evaluation_errors, + row_groups_matched_bloom_filter, row_groups_pruned_bloom_filter, + row_groups_matched_statistics, row_groups_pruned_statistics, bytes_scanned, pushdown_rows_filtered, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index ef2eb775e037..1a84f52a33fd 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -94,6 +94,7 @@ pub(crate) fn prune_row_groups_by_statistics( metrics.predicate_evaluation_errors.add(1); } } + metrics.row_groups_matched_statistics.add(1); } filtered.push(idx) @@ -166,6 +167,9 @@ pub(crate) async fn prune_row_groups_by_bloom_filters< if prune_group { metrics.row_groups_pruned_bloom_filter.add(1); } else { + if !stats.column_sbbf.is_empty() { + metrics.row_groups_matched_bloom_filter.add(1); + } filtered.push(*idx); } } diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 7649b6acd45c..c60780919489 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -117,16 +117,33 @@ impl TestOutput { self.metric_value("predicate_evaluation_errors") } + /// The number of row_groups matched by bloom filter + fn row_groups_matched_bloom_filter(&self) -> Option { + self.metric_value("row_groups_matched_bloom_filter") + } + /// The number of row_groups pruned by bloom filter fn row_groups_pruned_bloom_filter(&self) -> Option { self.metric_value("row_groups_pruned_bloom_filter") } + /// The number of row_groups matched by statistics + fn row_groups_matched_statistics(&self) -> Option { + self.metric_value("row_groups_matched_statistics") + } + /// The number of row_groups pruned by statistics fn row_groups_pruned_statistics(&self) -> Option { self.metric_value("row_groups_pruned_statistics") } + /// The number of row_groups matched by bloom filter or statistics + fn row_groups_matched(&self) -> Option { + self.row_groups_matched_bloom_filter() + .zip(self.row_groups_matched_statistics()) + .map(|(a, b)| a + b) + } + /// The number of row_groups pruned fn row_groups_pruned(&self) -> Option { self.row_groups_pruned_bloom_filter() diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index fa53b9c56cec..b7038ef1a73f 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -29,7 +29,9 @@ struct RowGroupPruningTest { scenario: Scenario, query: String, expected_errors: Option, + expected_row_group_matched_by_statistics: Option, expected_row_group_pruned_by_statistics: Option, + expected_row_group_matched_by_bloom_filter: Option, expected_row_group_pruned_by_bloom_filter: Option, expected_results: usize, } @@ -40,7 +42,9 @@ impl RowGroupPruningTest { scenario: Scenario::Timestamps, // or another default query: String::new(), expected_errors: None, + expected_row_group_matched_by_statistics: None, expected_row_group_pruned_by_statistics: None, + expected_row_group_matched_by_bloom_filter: None, expected_row_group_pruned_by_bloom_filter: None, expected_results: 0, } @@ -64,12 +68,24 @@ impl RowGroupPruningTest { self } + // Set the expected matched row groups by statistics + fn with_matched_by_stats(mut self, matched_by_stats: Option) -> Self { + self.expected_row_group_matched_by_statistics = matched_by_stats; + self + } + // Set the expected pruned row groups by statistics fn with_pruned_by_stats(mut self, pruned_by_stats: Option) -> Self { self.expected_row_group_pruned_by_statistics = pruned_by_stats; self } + // Set the expected matched row groups by bloom filter + fn with_matched_by_bloom_filter(mut self, matched_by_bf: Option) -> Self { + self.expected_row_group_matched_by_bloom_filter = matched_by_bf; + self + } + // Set the expected pruned row groups by bloom filter fn with_pruned_by_bloom_filter(mut self, pruned_by_bf: Option) -> Self { self.expected_row_group_pruned_by_bloom_filter = pruned_by_bf; @@ -90,20 +106,36 @@ impl RowGroupPruningTest { .await; println!("{}", output.description()); - assert_eq!(output.predicate_evaluation_errors(), self.expected_errors); + assert_eq!( + output.predicate_evaluation_errors(), + self.expected_errors, + "mismatched predicate_evaluation" + ); + assert_eq!( + output.row_groups_matched_statistics(), + self.expected_row_group_matched_by_statistics, + "mismatched row_groups_matched_statistics", + ); assert_eq!( output.row_groups_pruned_statistics(), - self.expected_row_group_pruned_by_statistics + self.expected_row_group_pruned_by_statistics, + "mismatched row_groups_pruned_statistics", + ); + assert_eq!( + output.row_groups_matched_bloom_filter(), + self.expected_row_group_matched_by_bloom_filter, + "mismatched row_groups_matched_bloom_filter", ); assert_eq!( output.row_groups_pruned_bloom_filter(), - self.expected_row_group_pruned_by_bloom_filter + self.expected_row_group_pruned_by_bloom_filter, + "mismatched row_groups_pruned_bloom_filter", ); assert_eq!( output.result_rows, self.expected_results, - "{}", - output.description() + "mismatched expected rows: {}", + output.description(), ); } } @@ -114,7 +146,9 @@ async fn prune_timestamps_nanos() { .with_scenario(Scenario::Timestamps) .with_query("SELECT * FROM t where nanos < to_timestamp('2020-01-02 01:01:11Z')") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) .test_row_group_prune() @@ -129,7 +163,9 @@ async fn prune_timestamps_micros() { "SELECT * FROM t where micros < to_timestamp_micros('2020-01-02 01:01:11Z')", ) .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) .test_row_group_prune() @@ -144,7 +180,9 @@ async fn prune_timestamps_millis() { "SELECT * FROM t where micros < to_timestamp_millis('2020-01-02 01:01:11Z')", ) .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) .test_row_group_prune() @@ -159,7 +197,9 @@ async fn prune_timestamps_seconds() { "SELECT * FROM t where seconds < to_timestamp_seconds('2020-01-02 01:01:11Z')", ) .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) .test_row_group_prune() @@ -172,7 +212,9 @@ async fn prune_date32() { .with_scenario(Scenario::Dates) .with_query("SELECT * FROM t where date32 < cast('2020-01-02' as date)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) .test_row_group_prune() @@ -201,6 +243,7 @@ async fn prune_date64() { println!("{}", output.description()); // This should prune out groups without error assert_eq!(output.predicate_evaluation_errors(), Some(0)); + assert_eq!(output.row_groups_matched(), Some(1)); assert_eq!(output.row_groups_pruned(), Some(3)); assert_eq!(output.result_rows, 1, "{}", output.description()); } @@ -211,7 +254,9 @@ async fn prune_disabled() { .with_scenario(Scenario::Timestamps) .with_query("SELECT * FROM t where nanos < to_timestamp('2020-01-02 01:01:11Z')") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) .test_row_group_prune() @@ -230,6 +275,7 @@ async fn prune_disabled() { // This should not prune any assert_eq!(output.predicate_evaluation_errors(), Some(0)); + assert_eq!(output.row_groups_matched(), Some(0)); assert_eq!(output.row_groups_pruned(), Some(0)); assert_eq!( output.result_rows, @@ -245,7 +291,9 @@ async fn prune_int32_lt() { .with_scenario(Scenario::Int32) .with_query("SELECT * FROM t where i < 1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(11) .test_row_group_prune() @@ -257,7 +305,9 @@ async fn prune_int32_lt() { .with_scenario(Scenario::Int32) .with_query("SELECT * FROM t where -i > -1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(11) .test_row_group_prune() @@ -270,7 +320,9 @@ async fn prune_int32_eq() { .with_scenario(Scenario::Int32) .with_query("SELECT * FROM t where i = 1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) .test_row_group_prune() @@ -282,7 +334,9 @@ async fn prune_int32_scalar_fun_and_eq() { .with_scenario(Scenario::Int32) .with_query("SELECT * FROM t where i = 1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) .test_row_group_prune() @@ -295,7 +349,9 @@ async fn prune_int32_scalar_fun() { .with_scenario(Scenario::Int32) .with_query("SELECT * FROM t where abs(i) = 1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(3) .test_row_group_prune() @@ -308,7 +364,9 @@ async fn prune_int32_complex_expr() { .with_scenario(Scenario::Int32) .with_query("SELECT * FROM t where i+1 = 1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) .test_row_group_prune() @@ -321,7 +379,9 @@ async fn prune_int32_complex_expr_subtract() { .with_scenario(Scenario::Int32) .with_query("SELECT * FROM t where 1-i > 1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(9) .test_row_group_prune() @@ -334,7 +394,9 @@ async fn prune_f64_lt() { .with_scenario(Scenario::Float64) .with_query("SELECT * FROM t where f < 1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(11) .test_row_group_prune() @@ -343,7 +405,9 @@ async fn prune_f64_lt() { .with_scenario(Scenario::Float64) .with_query("SELECT * FROM t where -f > -1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(11) .test_row_group_prune() @@ -358,7 +422,9 @@ async fn prune_f64_scalar_fun_and_gt() { .with_scenario(Scenario::Float64) .with_query("SELECT * FROM t where abs(f - 1) <= 0.000001 and f >= 0.1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) .test_row_group_prune() @@ -372,7 +438,9 @@ async fn prune_f64_scalar_fun() { .with_scenario(Scenario::Float64) .with_query("SELECT * FROM t where abs(f-1) <= 0.000001") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) .test_row_group_prune() @@ -386,7 +454,9 @@ async fn prune_f64_complex_expr() { .with_scenario(Scenario::Float64) .with_query("SELECT * FROM t where f+1 > 1.1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(9) .test_row_group_prune() @@ -400,7 +470,9 @@ async fn prune_f64_complex_expr_subtract() { .with_scenario(Scenario::Float64) .with_query("SELECT * FROM t where 1-f > 1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(9) .test_row_group_prune() @@ -414,7 +486,9 @@ async fn prune_int32_eq_in_list() { .with_scenario(Scenario::Int32) .with_query("SELECT * FROM t where i in (1)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) .test_row_group_prune() @@ -429,7 +503,9 @@ async fn prune_int32_eq_in_list_2() { .with_scenario(Scenario::Int32) .with_query("SELECT * FROM t where i in (1000)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(4)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(0) .test_row_group_prune() @@ -449,7 +525,9 @@ async fn prune_int32_eq_large_in_list() { .as_str(), ) .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(0) .test_row_group_prune() @@ -463,7 +541,9 @@ async fn prune_int32_eq_in_list_negated() { .with_scenario(Scenario::Int32) .with_query("SELECT * FROM t where i not in (1)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(4)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(4)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(19) .test_row_group_prune() @@ -479,7 +559,9 @@ async fn prune_decimal_lt() { .with_scenario(Scenario::Decimal) .with_query("SELECT * FROM t where decimal_col < 4") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(6) .test_row_group_prune() @@ -488,7 +570,9 @@ async fn prune_decimal_lt() { .with_scenario(Scenario::Decimal) .with_query("SELECT * FROM t where decimal_col < cast(4.55 as decimal(20,2))") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(8) .test_row_group_prune() @@ -497,7 +581,9 @@ async fn prune_decimal_lt() { .with_scenario(Scenario::DecimalLargePrecision) .with_query("SELECT * FROM t where decimal_col < 4") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(6) .test_row_group_prune() @@ -506,7 +592,9 @@ async fn prune_decimal_lt() { .with_scenario(Scenario::DecimalLargePrecision) .with_query("SELECT * FROM t where decimal_col < cast(4.55 as decimal(20,2))") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(8) .test_row_group_prune() @@ -522,7 +610,9 @@ async fn prune_decimal_eq() { .with_scenario(Scenario::Decimal) .with_query("SELECT * FROM t where decimal_col = 4") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) .test_row_group_prune() @@ -531,7 +621,9 @@ async fn prune_decimal_eq() { .with_scenario(Scenario::Decimal) .with_query("SELECT * FROM t where decimal_col = 4.00") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) .test_row_group_prune() @@ -541,7 +633,9 @@ async fn prune_decimal_eq() { .with_scenario(Scenario::DecimalLargePrecision) .with_query("SELECT * FROM t where decimal_col = 4") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) .test_row_group_prune() @@ -550,7 +644,9 @@ async fn prune_decimal_eq() { .with_scenario(Scenario::DecimalLargePrecision) .with_query("SELECT * FROM t where decimal_col = 4.00") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) .test_row_group_prune() @@ -567,7 +663,9 @@ async fn prune_decimal_in_list() { .with_scenario(Scenario::Decimal) .with_query("SELECT * FROM t where decimal_col in (4,3,2,123456789123)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(5) .test_row_group_prune() @@ -576,7 +674,9 @@ async fn prune_decimal_in_list() { .with_scenario(Scenario::Decimal) .with_query("SELECT * FROM t where decimal_col in (4.00,3.00,11.2345,1)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(6) .test_row_group_prune() @@ -585,7 +685,9 @@ async fn prune_decimal_in_list() { .with_scenario(Scenario::Decimal) .with_query("SELECT * FROM t where decimal_col in (4,3,2,123456789123)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(5) .test_row_group_prune() @@ -594,7 +696,9 @@ async fn prune_decimal_in_list() { .with_scenario(Scenario::DecimalLargePrecision) .with_query("SELECT * FROM t where decimal_col in (4.00,3.00,11.2345,1)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(6) .test_row_group_prune() @@ -605,7 +709,9 @@ async fn prune_decimal_in_list() { .with_scenario(Scenario::DecimalBloomFilterInt32) .with_query("SELECT * FROM t where decimal_col in (5)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(2)) .with_expected_rows(1) .test_row_group_prune() @@ -616,7 +722,9 @@ async fn prune_decimal_in_list() { .with_scenario(Scenario::DecimalBloomFilterInt64) .with_query("SELECT * FROM t where decimal_col in (5)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(2)) .with_expected_rows(1) .test_row_group_prune() @@ -627,7 +735,9 @@ async fn prune_decimal_in_list() { .with_scenario(Scenario::DecimalLargePrecisionBloomFilter) .with_query("SELECT * FROM t where decimal_col in (5)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(2)) .with_expected_rows(1) .test_row_group_prune() @@ -644,7 +754,9 @@ async fn prune_periods_in_column_names() { .with_scenario(Scenario::PeriodsInColumnNames) .with_query( "SELECT \"name\", \"service.name\" FROM t WHERE \"service.name\" = 'frontend'") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(7) .test_row_group_prune() @@ -653,7 +765,9 @@ async fn prune_periods_in_column_names() { .with_scenario(Scenario::PeriodsInColumnNames) .with_query( "SELECT \"name\", \"service.name\" FROM t WHERE \"name\" != 'HTTP GET / DISPATCH'") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(5) .test_row_group_prune() @@ -662,7 +776,9 @@ async fn prune_periods_in_column_names() { .with_scenario(Scenario::PeriodsInColumnNames) .with_query( "SELECT \"name\", \"service.name\" FROM t WHERE \"service.name\" = 'frontend' AND \"name\" != 'HTTP GET / DISPATCH'") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) .test_row_group_prune() diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 695b3ba745e2..30b11fe2a0ee 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -737,7 +737,9 @@ async fn parquet_explain_analyze() { // should contain aggregated stats assert_contains!(&formatted, "output_rows=8"); + assert_contains!(&formatted, "row_groups_matched_bloom_filter=0"); assert_contains!(&formatted, "row_groups_pruned_bloom_filter=0"); + assert_contains!(&formatted, "row_groups_matched_statistics=1"); assert_contains!(&formatted, "row_groups_pruned_statistics=0"); } @@ -754,7 +756,9 @@ async fn parquet_explain_analyze_verbose() { .to_string(); // should contain the raw per file stats (with the label) + assert_contains!(&formatted, "row_groups_matched_bloom_filter{partition=0"); assert_contains!(&formatted, "row_groups_pruned_bloom_filter{partition=0"); + assert_contains!(&formatted, "row_groups_matched_statistics{partition=0"); assert_contains!(&formatted, "row_groups_pruned_statistics{partition=0"); } From b137f60b9b6132d389efa9911b929d7b4d285b3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Metehan=20Y=C4=B1ld=C4=B1r=C4=B1m?= <100111937+metesynnada@users.noreply.github.com> Date: Tue, 19 Mar 2024 01:45:26 +0300 Subject: [PATCH 09/35] Make COPY TO align with CREATE EXTERNAL TABLE (#9604) --- datafusion-cli/src/catalog.rs | 2 +- datafusion-cli/src/exec.rs | 6 +- datafusion/common/src/config.rs | 221 +++++++++++++----- datafusion/common/src/file_options/mod.rs | 85 +++---- datafusion/core/src/dataframe/mod.rs | 9 +- datafusion/core/src/dataframe/parquet.rs | 5 +- .../src/datasource/file_format/options.rs | 2 +- .../core/src/datasource/listing/table.rs | 40 ++-- .../src/datasource/listing_table_factory.rs | 6 +- datafusion/core/src/execution/context/mod.rs | 15 +- datafusion/core/src/physical_planner.rs | 2 +- datafusion/core/src/test_util/parquet.rs | 2 +- datafusion/core/tests/sql/sql_api.rs | 12 +- .../tests/cases/roundtrip_logical_plan.rs | 9 +- datafusion/sql/src/parser.rs | 206 ++++++++++++---- datafusion/sql/src/statement.rs | 139 ++++------- datafusion/sql/tests/sql_integration.rs | 26 ++- datafusion/sqllogictest/test_files/copy.slt | 159 ++++++------- .../test_files/create_external_table.slt | 4 +- .../sqllogictest/test_files/csv_files.slt | 10 +- .../sqllogictest/test_files/group_by.slt | 8 +- .../sqllogictest/test_files/parquet.slt | 8 +- .../sqllogictest/test_files/repartition.slt | 2 +- .../test_files/repartition_scan.slt | 8 +- .../test_files/schema_evolution.slt | 8 +- docs/source/user-guide/sql/dml.md | 2 +- 26 files changed, 598 insertions(+), 398 deletions(-) diff --git a/datafusion-cli/src/catalog.rs b/datafusion-cli/src/catalog.rs index a8ecb98637cb..46dd8bb00f06 100644 --- a/datafusion-cli/src/catalog.rs +++ b/datafusion-cli/src/catalog.rs @@ -189,7 +189,7 @@ impl SchemaProvider for DynamicFileSchemaProvider { &state, table_url.scheme(), url, - state.default_table_options(), + &state.default_table_options(), ) .await?; state.runtime_env().register_object_store(url, store); diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index b11f1c202284..ea765ee8eceb 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -412,7 +412,7 @@ mod tests { ) })?; for location in locations { - let sql = format!("copy (values (1,2)) to '{}';", location); + let sql = format!("copy (values (1,2)) to '{}' STORED AS PARQUET;", location); let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?; for statement in statements { //Should not fail @@ -438,8 +438,8 @@ mod tests { let location = "s3://bucket/path/file.parquet"; // Missing region, use object_store defaults - let sql = format!("COPY (values (1,2)) TO '{location}' - (format parquet, 'aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}')"); + let sql = format!("COPY (values (1,2)) TO '{location}' STORED AS PARQUET + OPTIONS ('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}')"); copy_to_table_test(location, &sql).await?; Ok(()) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 68b9ec9dab94..968d8215ca4d 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -1109,58 +1109,163 @@ macro_rules! extensions_options { } } +/// Represents the configuration options available for handling different table formats within a data processing application. +/// This struct encompasses options for various file formats including CSV, Parquet, and JSON, allowing for flexible configuration +/// of parsing and writing behaviors specific to each format. Additionally, it supports extending functionality through custom extensions. #[derive(Debug, Clone, Default)] pub struct TableOptions { + /// Configuration options for CSV file handling. This includes settings like the delimiter, + /// quote character, and whether the first row is considered as headers. pub csv: CsvOptions, + + /// Configuration options for Parquet file handling. This includes settings for compression, + /// encoding, and other Parquet-specific file characteristics. pub parquet: TableParquetOptions, + + /// Configuration options for JSON file handling. pub json: JsonOptions, + + /// The current file format that the table operations should assume. This option allows + /// for dynamic switching between the supported file types (e.g., CSV, Parquet, JSON). pub current_format: Option, - /// Optional extensions registered using [`Extensions::insert`] + + /// Optional extensions that can be used to extend or customize the behavior of the table + /// options. Extensions can be registered using `Extensions::insert` and might include + /// custom file handling logic, additional configuration parameters, or other enhancements. pub extensions: Extensions, } impl ConfigField for TableOptions { + /// Visits configuration settings for the current file format, or all formats if none is selected. + /// + /// This method adapts the behavior based on whether a file format is currently selected in `current_format`. + /// If a format is selected, it visits only the settings relevant to that format. Otherwise, + /// it visits all available format settings. fn visit(&self, v: &mut V, _key_prefix: &str, _description: &'static str) { - self.csv.visit(v, "csv", ""); - self.parquet.visit(v, "parquet", ""); - self.json.visit(v, "json", ""); + if let Some(file_type) = &self.current_format { + match file_type { + #[cfg(feature = "parquet")] + FileType::PARQUET => self.parquet.visit(v, "format", ""), + FileType::CSV => self.csv.visit(v, "format", ""), + FileType::JSON => self.json.visit(v, "format", ""), + _ => {} + } + } else { + self.csv.visit(v, "csv", ""); + self.parquet.visit(v, "parquet", ""); + self.json.visit(v, "json", ""); + } } + /// Sets a configuration value for a specific key within `TableOptions`. + /// + /// This method delegates setting configuration values to the specific file format configurations, + /// based on the current format selected. If no format is selected, it returns an error. + /// + /// # Parameters + /// + /// * `key`: The configuration key specifying which setting to adjust, prefixed with the format (e.g., "format.delimiter") + /// for CSV format. + /// * `value`: The value to set for the specified configuration key. + /// + /// # Returns + /// + /// A result indicating success or an error if the key is not recognized, if a format is not specified, + /// or if setting the configuration value fails for the specific format. fn set(&mut self, key: &str, value: &str) -> Result<()> { // Extensions are handled in the public `ConfigOptions::set` let (key, rem) = key.split_once('.').unwrap_or((key, "")); + let Some(format) = &self.current_format else { + return _config_err!("Specify a format for TableOptions"); + }; match key { - "csv" => self.csv.set(rem, value), - "parquet" => self.parquet.set(rem, value), - "json" => self.json.set(rem, value), + "format" => match format { + #[cfg(feature = "parquet")] + FileType::PARQUET => self.parquet.set(rem, value), + FileType::CSV => self.csv.set(rem, value), + FileType::JSON => self.json.set(rem, value), + _ => { + _config_err!("Config value \"{key}\" is not supported on {}", format) + } + }, _ => _config_err!("Config value \"{key}\" not found on TableOptions"), } } } impl TableOptions { - /// Creates a new [`ConfigOptions`] with default values + /// Constructs a new instance of `TableOptions` with default settings. + /// + /// # Returns + /// + /// A new `TableOptions` instance with default configuration values. pub fn new() -> Self { Self::default() } + /// Sets the file format for the table. + /// + /// # Parameters + /// + /// * `format`: The file format to use (e.g., CSV, Parquet). pub fn set_file_format(&mut self, format: FileType) { self.current_format = Some(format); } + /// Creates a new `TableOptions` instance initialized with settings from a given session config. + /// + /// # Parameters + /// + /// * `config`: A reference to the session `ConfigOptions` from which to derive initial settings. + /// + /// # Returns + /// + /// A new `TableOptions` instance with settings applied from the session config. pub fn default_from_session_config(config: &ConfigOptions) -> Self { - let mut initial = TableOptions::default(); - initial.parquet.global = config.execution.parquet.clone(); + let initial = TableOptions::default(); + initial.combine_with_session_config(config); initial } - /// Set extensions to provided value + /// Updates the current `TableOptions` with settings from a given session config. + /// + /// # Parameters + /// + /// * `config`: A reference to the session `ConfigOptions` whose settings are to be applied. + /// + /// # Returns + /// + /// A new `TableOptions` instance with updated settings from the session config. + pub fn combine_with_session_config(&self, config: &ConfigOptions) -> Self { + let mut clone = self.clone(); + clone.parquet.global = config.execution.parquet.clone(); + clone + } + + /// Sets the extensions for this `TableOptions` instance. + /// + /// # Parameters + /// + /// * `extensions`: The `Extensions` instance to set. + /// + /// # Returns + /// + /// A new `TableOptions` instance with the specified extensions applied. pub fn with_extensions(mut self, extensions: Extensions) -> Self { self.extensions = extensions; self } - /// Set a configuration option + /// Sets a specific configuration option. + /// + /// # Parameters + /// + /// * `key`: The configuration key (e.g., "format.delimiter"). + /// * `value`: The value to set for the specified key. + /// + /// # Returns + /// + /// A result indicating success or failure in setting the configuration option. pub fn set(&mut self, key: &str, value: &str) -> Result<()> { let (prefix, _) = key.split_once('.').ok_or_else(|| { DataFusionError::Configuration(format!( @@ -1168,28 +1273,7 @@ impl TableOptions { )) })?; - if prefix == "csv" || prefix == "json" || prefix == "parquet" { - if let Some(format) = &self.current_format { - match format { - FileType::CSV if prefix != "csv" => { - return Err(DataFusionError::Configuration(format!( - "Key \"{key}\" is not applicable for CSV format" - ))) - } - #[cfg(feature = "parquet")] - FileType::PARQUET if prefix != "parquet" => { - return Err(DataFusionError::Configuration(format!( - "Key \"{key}\" is not applicable for PARQUET format" - ))) - } - FileType::JSON if prefix != "json" => { - return Err(DataFusionError::Configuration(format!( - "Key \"{key}\" is not applicable for JSON format" - ))) - } - _ => {} - } - } + if prefix == "format" { return ConfigField::set(self, key, value); } @@ -1202,6 +1286,15 @@ impl TableOptions { e.0.set(key, value) } + /// Initializes a new `TableOptions` from a hash map of string settings. + /// + /// # Parameters + /// + /// * `settings`: A hash map where each key-value pair represents a configuration setting. + /// + /// # Returns + /// + /// A result containing the new `TableOptions` instance or an error if any setting could not be applied. pub fn from_string_hash_map(settings: &HashMap) -> Result { let mut ret = Self::default(); for (k, v) in settings { @@ -1211,6 +1304,15 @@ impl TableOptions { Ok(ret) } + /// Modifies the current `TableOptions` instance with settings from a hash map. + /// + /// # Parameters + /// + /// * `settings`: A hash map where each key-value pair represents a configuration setting. + /// + /// # Returns + /// + /// A result indicating success or failure in applying the settings. pub fn alter_with_string_hash_map( &mut self, settings: &HashMap, @@ -1221,7 +1323,11 @@ impl TableOptions { Ok(()) } - /// Returns the [`ConfigEntry`] stored within this [`ConfigOptions`] + /// Retrieves all configuration entries from this `TableOptions`. + /// + /// # Returns + /// + /// A vector of `ConfigEntry` instances, representing all the configuration options within this `TableOptions`. pub fn entries(&self) -> Vec { struct Visitor(Vec); @@ -1249,9 +1355,7 @@ impl TableOptions { } let mut v = Visitor(vec![]); - self.visit(&mut v, "csv", ""); - self.visit(&mut v, "json", ""); - self.visit(&mut v, "parquet", ""); + self.visit(&mut v, "format", ""); v.0.extend(self.extensions.0.values().flat_map(|e| e.0.entries())); v.0 @@ -1556,6 +1660,7 @@ mod tests { use crate::config::{ ConfigEntry, ConfigExtension, ExtensionOptions, Extensions, TableOptions, }; + use crate::FileType; #[derive(Default, Debug, Clone)] pub struct TestExtensionConfig { @@ -1609,12 +1714,13 @@ mod tests { } #[test] - fn alter_kafka_config() { + fn alter_test_extension_config() { let mut extension = Extensions::new(); extension.insert(TestExtensionConfig::default()); let mut table_config = TableOptions::new().with_extensions(extension); - table_config.set("parquet.write_batch_size", "10").unwrap(); - assert_eq!(table_config.parquet.global.write_batch_size, 10); + table_config.set_file_format(FileType::CSV); + table_config.set("format.delimiter", ";").unwrap(); + assert_eq!(table_config.csv.delimiter, b';'); table_config.set("test.bootstrap.servers", "asd").unwrap(); let kafka_config = table_config .extensions @@ -1626,11 +1732,25 @@ mod tests { ); } + #[test] + fn csv_u8_table_options() { + let mut table_config = TableOptions::new(); + table_config.set_file_format(FileType::CSV); + table_config.set("format.delimiter", ";").unwrap(); + assert_eq!(table_config.csv.delimiter as char, ';'); + table_config.set("format.escape", "\"").unwrap(); + assert_eq!(table_config.csv.escape.unwrap() as char, '"'); + table_config.set("format.escape", "\'").unwrap(); + assert_eq!(table_config.csv.escape.unwrap() as char, '\''); + } + + #[cfg(feature = "parquet")] #[test] fn parquet_table_options() { let mut table_config = TableOptions::new(); + table_config.set_file_format(FileType::PARQUET); table_config - .set("parquet.bloom_filter_enabled::col1", "true") + .set("format.bloom_filter_enabled::col1", "true") .unwrap(); assert_eq!( table_config.parquet.column_specific_options["col1"].bloom_filter_enabled, @@ -1638,26 +1758,17 @@ mod tests { ); } - #[test] - fn csv_u8_table_options() { - let mut table_config = TableOptions::new(); - table_config.set("csv.delimiter", ";").unwrap(); - assert_eq!(table_config.csv.delimiter as char, ';'); - table_config.set("csv.escape", "\"").unwrap(); - assert_eq!(table_config.csv.escape.unwrap() as char, '"'); - table_config.set("csv.escape", "\'").unwrap(); - assert_eq!(table_config.csv.escape.unwrap() as char, '\''); - } - + #[cfg(feature = "parquet")] #[test] fn parquet_table_options_config_entry() { let mut table_config = TableOptions::new(); + table_config.set_file_format(FileType::PARQUET); table_config - .set("parquet.bloom_filter_enabled::col1", "true") + .set("format.bloom_filter_enabled::col1", "true") .unwrap(); let entries = table_config.entries(); assert!(entries .iter() - .any(|item| item.key == "parquet.bloom_filter_enabled::col1")) + .any(|item| item.key == "format.bloom_filter_enabled::col1")) } } diff --git a/datafusion/common/src/file_options/mod.rs b/datafusion/common/src/file_options/mod.rs index a72b812adc8d..eb1ce1b364fd 100644 --- a/datafusion/common/src/file_options/mod.rs +++ b/datafusion/common/src/file_options/mod.rs @@ -35,7 +35,7 @@ mod tests { config::TableOptions, file_options::{csv_writer::CsvWriterOptions, json_writer::JsonWriterOptions}, parsers::CompressionTypeVariant, - Result, + FileType, Result, }; use parquet::{ @@ -47,35 +47,36 @@ mod tests { #[test] fn test_writeroptions_parquet_from_statement_options() -> Result<()> { let mut option_map: HashMap = HashMap::new(); - option_map.insert("parquet.max_row_group_size".to_owned(), "123".to_owned()); - option_map.insert("parquet.data_pagesize_limit".to_owned(), "123".to_owned()); - option_map.insert("parquet.write_batch_size".to_owned(), "123".to_owned()); - option_map.insert("parquet.writer_version".to_owned(), "2.0".to_owned()); + option_map.insert("format.max_row_group_size".to_owned(), "123".to_owned()); + option_map.insert("format.data_pagesize_limit".to_owned(), "123".to_owned()); + option_map.insert("format.write_batch_size".to_owned(), "123".to_owned()); + option_map.insert("format.writer_version".to_owned(), "2.0".to_owned()); option_map.insert( - "parquet.dictionary_page_size_limit".to_owned(), + "format.dictionary_page_size_limit".to_owned(), "123".to_owned(), ); option_map.insert( - "parquet.created_by".to_owned(), + "format.created_by".to_owned(), "df write unit test".to_owned(), ); option_map.insert( - "parquet.column_index_truncate_length".to_owned(), + "format.column_index_truncate_length".to_owned(), "123".to_owned(), ); option_map.insert( - "parquet.data_page_row_count_limit".to_owned(), + "format.data_page_row_count_limit".to_owned(), "123".to_owned(), ); - option_map.insert("parquet.bloom_filter_enabled".to_owned(), "true".to_owned()); - option_map.insert("parquet.encoding".to_owned(), "plain".to_owned()); - option_map.insert("parquet.dictionary_enabled".to_owned(), "true".to_owned()); - option_map.insert("parquet.compression".to_owned(), "zstd(4)".to_owned()); - option_map.insert("parquet.statistics_enabled".to_owned(), "page".to_owned()); - option_map.insert("parquet.bloom_filter_fpp".to_owned(), "0.123".to_owned()); - option_map.insert("parquet.bloom_filter_ndv".to_owned(), "123".to_owned()); + option_map.insert("format.bloom_filter_enabled".to_owned(), "true".to_owned()); + option_map.insert("format.encoding".to_owned(), "plain".to_owned()); + option_map.insert("format.dictionary_enabled".to_owned(), "true".to_owned()); + option_map.insert("format.compression".to_owned(), "zstd(4)".to_owned()); + option_map.insert("format.statistics_enabled".to_owned(), "page".to_owned()); + option_map.insert("format.bloom_filter_fpp".to_owned(), "0.123".to_owned()); + option_map.insert("format.bloom_filter_ndv".to_owned(), "123".to_owned()); let mut table_config = TableOptions::new(); + table_config.set_file_format(FileType::PARQUET); table_config.alter_with_string_hash_map(&option_map)?; let parquet_options = ParquetWriterOptions::try_from(&table_config.parquet)?; @@ -131,54 +132,52 @@ mod tests { let mut option_map: HashMap = HashMap::new(); option_map.insert( - "parquet.bloom_filter_enabled::col1".to_owned(), + "format.bloom_filter_enabled::col1".to_owned(), "true".to_owned(), ); option_map.insert( - "parquet.bloom_filter_enabled::col2.nested".to_owned(), + "format.bloom_filter_enabled::col2.nested".to_owned(), "true".to_owned(), ); - option_map.insert("parquet.encoding::col1".to_owned(), "plain".to_owned()); - option_map.insert("parquet.encoding::col2.nested".to_owned(), "rle".to_owned()); + option_map.insert("format.encoding::col1".to_owned(), "plain".to_owned()); + option_map.insert("format.encoding::col2.nested".to_owned(), "rle".to_owned()); option_map.insert( - "parquet.dictionary_enabled::col1".to_owned(), + "format.dictionary_enabled::col1".to_owned(), "true".to_owned(), ); option_map.insert( - "parquet.dictionary_enabled::col2.nested".to_owned(), + "format.dictionary_enabled::col2.nested".to_owned(), "true".to_owned(), ); - option_map.insert("parquet.compression::col1".to_owned(), "zstd(4)".to_owned()); + option_map.insert("format.compression::col1".to_owned(), "zstd(4)".to_owned()); option_map.insert( - "parquet.compression::col2.nested".to_owned(), + "format.compression::col2.nested".to_owned(), "zstd(10)".to_owned(), ); option_map.insert( - "parquet.statistics_enabled::col1".to_owned(), + "format.statistics_enabled::col1".to_owned(), "page".to_owned(), ); option_map.insert( - "parquet.statistics_enabled::col2.nested".to_owned(), + "format.statistics_enabled::col2.nested".to_owned(), "none".to_owned(), ); option_map.insert( - "parquet.bloom_filter_fpp::col1".to_owned(), + "format.bloom_filter_fpp::col1".to_owned(), "0.123".to_owned(), ); option_map.insert( - "parquet.bloom_filter_fpp::col2.nested".to_owned(), + "format.bloom_filter_fpp::col2.nested".to_owned(), "0.456".to_owned(), ); + option_map.insert("format.bloom_filter_ndv::col1".to_owned(), "123".to_owned()); option_map.insert( - "parquet.bloom_filter_ndv::col1".to_owned(), - "123".to_owned(), - ); - option_map.insert( - "parquet.bloom_filter_ndv::col2.nested".to_owned(), + "format.bloom_filter_ndv::col2.nested".to_owned(), "456".to_owned(), ); let mut table_config = TableOptions::new(); + table_config.set_file_format(FileType::PARQUET); table_config.alter_with_string_hash_map(&option_map)?; let parquet_options = ParquetWriterOptions::try_from(&table_config.parquet)?; @@ -271,16 +270,17 @@ mod tests { // for StatementOptions fn test_writeroptions_csv_from_statement_options() -> Result<()> { let mut option_map: HashMap = HashMap::new(); - option_map.insert("csv.has_header".to_owned(), "true".to_owned()); - option_map.insert("csv.date_format".to_owned(), "123".to_owned()); - option_map.insert("csv.datetime_format".to_owned(), "123".to_owned()); - option_map.insert("csv.timestamp_format".to_owned(), "2.0".to_owned()); - option_map.insert("csv.time_format".to_owned(), "123".to_owned()); - option_map.insert("csv.null_value".to_owned(), "123".to_owned()); - option_map.insert("csv.compression".to_owned(), "gzip".to_owned()); - option_map.insert("csv.delimiter".to_owned(), ";".to_owned()); + option_map.insert("format.has_header".to_owned(), "true".to_owned()); + option_map.insert("format.date_format".to_owned(), "123".to_owned()); + option_map.insert("format.datetime_format".to_owned(), "123".to_owned()); + option_map.insert("format.timestamp_format".to_owned(), "2.0".to_owned()); + option_map.insert("format.time_format".to_owned(), "123".to_owned()); + option_map.insert("format.null_value".to_owned(), "123".to_owned()); + option_map.insert("format.compression".to_owned(), "gzip".to_owned()); + option_map.insert("format.delimiter".to_owned(), ";".to_owned()); let mut table_config = TableOptions::new(); + table_config.set_file_format(FileType::CSV); table_config.alter_with_string_hash_map(&option_map)?; let csv_options = CsvWriterOptions::try_from(&table_config.csv)?; @@ -299,9 +299,10 @@ mod tests { // for StatementOptions fn test_writeroptions_json_from_statement_options() -> Result<()> { let mut option_map: HashMap = HashMap::new(); - option_map.insert("json.compression".to_owned(), "gzip".to_owned()); + option_map.insert("format.compression".to_owned(), "gzip".to_owned()); let mut table_config = TableOptions::new(); + table_config.set_file_format(FileType::JSON); table_config.alter_with_string_hash_map(&option_map)?; let json_options = JsonWriterOptions::try_from(&table_config.json)?; diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 25830401571d..eea5fc1127ce 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1151,8 +1151,8 @@ impl DataFrame { "Overwrites are not implemented for DataFrame::write_csv.".to_owned(), )); } - let table_options = self.session_state.default_table_options(); - let props = writer_options.unwrap_or_else(|| table_options.csv.clone()); + let props = writer_options + .unwrap_or_else(|| self.session_state.default_table_options().csv); let plan = LogicalPlanBuilder::copy_to( self.plan, @@ -1200,9 +1200,8 @@ impl DataFrame { )); } - let table_options = self.session_state.default_table_options(); - - let props = writer_options.unwrap_or_else(|| table_options.json.clone()); + let props = writer_options + .unwrap_or_else(|| self.session_state.default_table_options().json); let plan = LogicalPlanBuilder::copy_to( self.plan, diff --git a/datafusion/core/src/dataframe/parquet.rs b/datafusion/core/src/dataframe/parquet.rs index f4e8c9dfcd6f..e3f606e322fe 100644 --- a/datafusion/core/src/dataframe/parquet.rs +++ b/datafusion/core/src/dataframe/parquet.rs @@ -57,9 +57,8 @@ impl DataFrame { )); } - let table_options = self.session_state.default_table_options(); - - let props = writer_options.unwrap_or_else(|| table_options.parquet.clone()); + let props = writer_options + .unwrap_or_else(|| self.session_state.default_table_options().parquet); let plan = LogicalPlanBuilder::copy_to( self.plan, diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index f66683c311c1..f5bd72495d66 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -461,7 +461,7 @@ pub trait ReadOptions<'a> { return Ok(Arc::new(s.to_owned())); } - self.to_listing_options(config, state.default_table_options().clone()) + self.to_listing_options(config, state.default_table_options()) .infer_schema(&state, &table_path) .await } diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 2a2551236e1b..c1e337b5c44a 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -118,7 +118,7 @@ impl ListingTableConfig { } } - fn infer_format(path: &str) -> Result<(Arc, String)> { + fn infer_file_type(path: &str) -> Result<(FileType, String)> { let err_msg = format!("Unable to infer file type from path: {path}"); let mut exts = path.rsplit('.'); @@ -139,20 +139,7 @@ impl ListingTableConfig { .get_ext_with_compression(file_compression_type.to_owned()) .map_err(|_| DataFusionError::Internal(err_msg))?; - let file_format: Arc = match file_type { - FileType::ARROW => Arc::new(ArrowFormat), - FileType::AVRO => Arc::new(AvroFormat), - FileType::CSV => Arc::new( - CsvFormat::default().with_file_compression_type(file_compression_type), - ), - FileType::JSON => Arc::new( - JsonFormat::default().with_file_compression_type(file_compression_type), - ), - #[cfg(feature = "parquet")] - FileType::PARQUET => Arc::new(ParquetFormat::default()), - }; - - Ok((file_format, ext)) + Ok((file_type, ext)) } /// Infer `ListingOptions` based on `table_path` suffix. @@ -173,10 +160,27 @@ impl ListingTableConfig { .await .ok_or_else(|| DataFusionError::Internal("No files for table".into()))??; - let (format, file_extension) = - ListingTableConfig::infer_format(file.location.as_ref())?; + let (file_type, file_extension) = + ListingTableConfig::infer_file_type(file.location.as_ref())?; + + let mut table_options = state.default_table_options(); + table_options.set_file_format(file_type.clone()); + let file_format: Arc = match file_type { + FileType::CSV => { + Arc::new(CsvFormat::default().with_options(table_options.csv)) + } + #[cfg(feature = "parquet")] + FileType::PARQUET => { + Arc::new(ParquetFormat::default().with_options(table_options.parquet)) + } + FileType::AVRO => Arc::new(AvroFormat), + FileType::JSON => { + Arc::new(JsonFormat::default().with_options(table_options.json)) + } + FileType::ARROW => Arc::new(ArrowFormat), + }; - let listing_options = ListingOptions::new(format) + let listing_options = ListingOptions::new(file_format) .with_file_extension(file_extension) .with_target_partitions(state.config().target_partitions()); diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 4e126bbba9f9..b616e0181cfc 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -34,7 +34,6 @@ use crate::datasource::TableProvider; use crate::execution::context::SessionState; use arrow::datatypes::{DataType, SchemaRef}; -use datafusion_common::config::TableOptions; use datafusion_common::{arrow_datafusion_err, DataFusionError, FileType}; use datafusion_expr::CreateExternalTable; @@ -58,8 +57,7 @@ impl TableProviderFactory for ListingTableFactory { state: &SessionState, cmd: &CreateExternalTable, ) -> datafusion_common::Result> { - let mut table_options = - TableOptions::default_from_session_config(state.config_options()); + let mut table_options = state.default_table_options(); let file_type = FileType::from_str(cmd.file_type.as_str()).map_err(|_| { DataFusionError::Execution(format!("Unknown FileType {}", cmd.file_type)) })?; @@ -227,7 +225,7 @@ mod tests { let name = OwnedTableReference::bare("foo".to_string()); let mut options = HashMap::new(); - options.insert("csv.schema_infer_max_rec".to_owned(), "1000".to_owned()); + options.insert("format.schema_infer_max_rec".to_owned(), "1000".to_owned()); let cmd = CreateExternalTable { name, location: csv_file.path().to_str().unwrap().to_string(), diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 1ac7da465216..116e45c8c130 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -384,9 +384,9 @@ impl SessionContext { self.state.read().config.clone() } - /// Return a copied version of config for this Session + /// Return a copied version of table options for this Session pub fn copied_table_options(&self) -> TableOptions { - self.state.read().default_table_options().clone() + self.state.read().default_table_options() } /// Creates a [`DataFrame`] from SQL query text. @@ -1750,11 +1750,7 @@ impl SessionState { .0 .insert(ObjectName(vec![Ident::from(table.name.as_str())])); } - DFStatement::CopyTo(CopyToStatement { - source, - target: _, - options: _, - }) => match source { + DFStatement::CopyTo(CopyToStatement { source, .. }) => match source { CopyToSource::Relation(table_name) => { visitor.insert(table_name); } @@ -1963,8 +1959,9 @@ impl SessionState { } /// return the TableOptions options with its extensions - pub fn default_table_options(&self) -> &TableOptions { - &self.table_option_namespace + pub fn default_table_options(&self) -> TableOptions { + self.table_option_namespace + .combine_with_session_config(self.config_options()) } /// Get a new TaskContext to run in this session diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 96f5e1c3ffd3..ee581ca64214 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -595,7 +595,7 @@ impl DefaultPhysicalPlanner { table_partition_cols, overwrite: false, }; - let mut table_options = session_state.default_table_options().clone(); + let mut table_options = session_state.default_table_options(); let sink_format: Arc = match format_options { FormatOptions::CSV(options) => { table_options.csv = options.clone(); diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index 7a466a666d8d..8113d799a184 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -165,7 +165,7 @@ impl TestParquetFile { // run coercion on the filters to coerce types etc. let props = ExecutionProps::new(); let context = SimplifyContext::new(&props).with_schema(df_schema.clone()); - let parquet_options = ctx.state().default_table_options().parquet.clone(); + let parquet_options = ctx.copied_table_options().parquet; if let Some(filter) = maybe_filter { let simplifier = ExprSimplifier::new(context); let filter = simplifier.coerce(filter, df_schema.clone()).unwrap(); diff --git a/datafusion/core/tests/sql/sql_api.rs b/datafusion/core/tests/sql/sql_api.rs index d7adc9611b2f..b3a819fbc331 100644 --- a/datafusion/core/tests/sql/sql_api.rs +++ b/datafusion/core/tests/sql/sql_api.rs @@ -16,6 +16,7 @@ // under the License. use datafusion::prelude::*; + use tempfile::TempDir; #[tokio::test] @@ -27,7 +28,7 @@ async fn unsupported_ddl_returns_error() { // disallow ddl let options = SQLOptions::new().with_allow_ddl(false); - let sql = "create view test_view as select * from test"; + let sql = "CREATE VIEW test_view AS SELECT * FROM test"; let df = ctx.sql_with_options(sql, options).await; assert_eq!( df.unwrap_err().strip_backtrace(), @@ -46,7 +47,7 @@ async fn unsupported_dml_returns_error() { let options = SQLOptions::new().with_allow_dml(false); - let sql = "insert into test values (1)"; + let sql = "INSERT INTO test VALUES (1)"; let df = ctx.sql_with_options(sql, options).await; assert_eq!( df.unwrap_err().strip_backtrace(), @@ -67,7 +68,10 @@ async fn unsupported_copy_returns_error() { let options = SQLOptions::new().with_allow_dml(false); - let sql = format!("copy (values(1)) to '{}'", tmpfile.to_string_lossy()); + let sql = format!( + "COPY (values(1)) TO '{}' STORED AS parquet", + tmpfile.to_string_lossy() + ); let df = ctx.sql_with_options(&sql, options).await; assert_eq!( df.unwrap_err().strip_backtrace(), @@ -106,7 +110,7 @@ async fn ddl_can_not_be_planned_by_session_state() { let state = ctx.state(); // can not create a logical plan for catalog DDL - let sql = "drop table test"; + let sql = "DROP TABLE test"; let plan = state.create_logical_plan(sql).await.unwrap(); let physical_plan = state.create_physical_plan(&plan).await; assert_eq!( diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 93de560dbee5..3c43f100750f 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -35,7 +35,7 @@ use datafusion_common::config::{FormatOptions, TableOptions}; use datafusion_common::scalar::ScalarStructBuilder; use datafusion_common::{ internal_err, not_impl_err, plan_err, DFField, DFSchema, DFSchemaRef, - DataFusionError, Result, ScalarValue, + DataFusionError, FileType, Result, ScalarValue, }; use datafusion_expr::dml::CopyTo; use datafusion_expr::expr::{ @@ -314,10 +314,9 @@ async fn roundtrip_logical_plan_copy_to_sql_options() -> Result<()> { let ctx = SessionContext::new(); let input = create_csv_scan(&ctx).await?; - - let mut table_options = - TableOptions::default_from_session_config(ctx.state().config_options()); - table_options.set("csv.delimiter", ";")?; + let mut table_options = ctx.copied_table_options(); + table_options.set_file_format(FileType::CSV); + table_options.set("format.delimiter", ";")?; let plan = LogicalPlan::Copy(CopyTo { input: Arc::new(input), diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index effc1d096cfd..a5d7970495c5 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -17,21 +17,20 @@ //! [`DFParser`]: DataFusion SQL Parser based on [`sqlparser`] +use std::collections::{HashMap, VecDeque}; +use std::fmt; +use std::str::FromStr; + use datafusion_common::parsers::CompressionTypeVariant; -use sqlparser::ast::{OrderByExpr, Query, Value}; -use sqlparser::tokenizer::Word; use sqlparser::{ ast::{ - ColumnDef, ColumnOptionDef, ObjectName, Statement as SQLStatement, - TableConstraint, + ColumnDef, ColumnOptionDef, ObjectName, OrderByExpr, Query, + Statement as SQLStatement, TableConstraint, Value, }, dialect::{keywords::Keyword, Dialect, GenericDialect}, parser::{Parser, ParserError}, - tokenizer::{Token, TokenWithLocation, Tokenizer}, + tokenizer::{Token, TokenWithLocation, Tokenizer, Word}, }; -use std::collections::VecDeque; -use std::fmt; -use std::{collections::HashMap, str::FromStr}; // Use `Parser::expected` instead, if possible macro_rules! parser_err { @@ -102,6 +101,12 @@ pub struct CopyToStatement { pub source: CopyToSource, /// The URL to where the data is heading pub target: String, + /// Partition keys + pub partitioned_by: Vec, + /// Indicates whether there is a header row (e.g. CSV) + pub has_header: bool, + /// File type (Parquet, NDJSON, CSV etc.) + pub stored_as: Option, /// Target specific options pub options: Vec<(String, Value)>, } @@ -111,15 +116,27 @@ impl fmt::Display for CopyToStatement { let Self { source, target, + partitioned_by, + stored_as, options, + .. } = self; write!(f, "COPY {source} TO {target}")?; + if let Some(file_type) = stored_as { + write!(f, " STORED AS {}", file_type)?; + } + if !partitioned_by.is_empty() { + write!(f, " PARTITIONED BY ({})", partitioned_by.join(", "))?; + } + + if self.has_header { + write!(f, " WITH HEADER ROW")?; + } if !options.is_empty() { let opts: Vec<_> = options.iter().map(|(k, v)| format!("{k} {v}")).collect(); - // print them in sorted order - write!(f, " ({})", opts.join(", "))?; + write!(f, " OPTIONS ({})", opts.join(", "))?; } Ok(()) @@ -243,6 +260,15 @@ impl fmt::Display for Statement { } } +fn ensure_not_set(field: &Option, name: &str) -> Result<(), ParserError> { + if field.is_some() { + return Err(ParserError::ParserError(format!( + "{name} specified more than once", + ))); + } + Ok(()) +} + /// Datafusion SQL Parser based on [`sqlparser`] /// /// Parses DataFusion's SQL dialect, often delegating to [`sqlparser`]'s [`Parser`]. @@ -370,21 +396,79 @@ impl<'a> DFParser<'a> { CopyToSource::Relation(table_name) }; - self.parser.expect_keyword(Keyword::TO)?; + #[derive(Default)] + struct Builder { + stored_as: Option, + target: Option, + partitioned_by: Option>, + has_header: Option, + options: Option>, + } - let target = self.parser.parse_literal_string()?; + let mut builder = Builder::default(); - // check for options in parens - let options = if self.parser.peek_token().token == Token::LParen { - self.parse_value_options()? - } else { - vec![] + loop { + if let Some(keyword) = self.parser.parse_one_of_keywords(&[ + Keyword::STORED, + Keyword::TO, + Keyword::PARTITIONED, + Keyword::OPTIONS, + Keyword::WITH, + ]) { + match keyword { + Keyword::STORED => { + self.parser.expect_keyword(Keyword::AS)?; + ensure_not_set(&builder.stored_as, "STORED AS")?; + builder.stored_as = Some(self.parse_file_format()?); + } + Keyword::TO => { + ensure_not_set(&builder.target, "TO")?; + builder.target = Some(self.parser.parse_literal_string()?); + } + Keyword::WITH => { + self.parser.expect_keyword(Keyword::HEADER)?; + self.parser.expect_keyword(Keyword::ROW)?; + ensure_not_set(&builder.has_header, "WITH HEADER ROW")?; + builder.has_header = Some(true); + } + Keyword::PARTITIONED => { + self.parser.expect_keyword(Keyword::BY)?; + ensure_not_set(&builder.partitioned_by, "PARTITIONED BY")?; + builder.partitioned_by = Some(self.parse_partitions()?); + } + Keyword::OPTIONS => { + ensure_not_set(&builder.options, "OPTIONS")?; + builder.options = Some(self.parse_value_options()?); + } + _ => { + unreachable!() + } + } + } else { + let token = self.parser.next_token(); + if token == Token::EOF || token == Token::SemiColon { + break; + } else { + return Err(ParserError::ParserError(format!( + "Unexpected token {token}" + ))); + } + } + } + + let Some(target) = builder.target else { + return Err(ParserError::ParserError( + "Missing TO clause in COPY statement".into(), + )); }; Ok(Statement::CopyTo(CopyToStatement { source, target, - options, + partitioned_by: builder.partitioned_by.unwrap_or(vec![]), + has_header: builder.has_header.unwrap_or(false), + stored_as: builder.stored_as, + options: builder.options.unwrap_or(vec![]), })) } @@ -624,15 +708,6 @@ impl<'a> DFParser<'a> { } let mut builder = Builder::default(); - fn ensure_not_set(field: &Option, name: &str) -> Result<(), ParserError> { - if field.is_some() { - return Err(ParserError::ParserError(format!( - "{name} specified more than once", - ))); - } - Ok(()) - } - loop { if let Some(keyword) = self.parser.parse_one_of_keywords(&[ Keyword::STORED, @@ -1321,10 +1396,13 @@ mod tests { #[test] fn copy_to_table_to_table() -> Result<(), ParserError> { // positive case - let sql = "COPY foo TO bar"; + let sql = "COPY foo TO bar STORED AS CSV"; let expected = Statement::CopyTo(CopyToStatement { source: object_name("foo"), target: "bar".to_string(), + partitioned_by: vec![], + has_header: false, + stored_as: Some("CSV".to_owned()), options: vec![], }); @@ -1335,10 +1413,22 @@ mod tests { #[test] fn explain_copy_to_table_to_table() -> Result<(), ParserError> { let cases = vec![ - ("EXPLAIN COPY foo TO bar", false, false), - ("EXPLAIN ANALYZE COPY foo TO bar", true, false), - ("EXPLAIN VERBOSE COPY foo TO bar", false, true), - ("EXPLAIN ANALYZE VERBOSE COPY foo TO bar", true, true), + ("EXPLAIN COPY foo TO bar STORED AS PARQUET", false, false), + ( + "EXPLAIN ANALYZE COPY foo TO bar STORED AS PARQUET", + true, + false, + ), + ( + "EXPLAIN VERBOSE COPY foo TO bar STORED AS PARQUET", + false, + true, + ), + ( + "EXPLAIN ANALYZE VERBOSE COPY foo TO bar STORED AS PARQUET", + true, + true, + ), ]; for (sql, analyze, verbose) in cases { println!("sql: {sql}, analyze: {analyze}, verbose: {verbose}"); @@ -1346,6 +1436,9 @@ mod tests { let expected_copy = Statement::CopyTo(CopyToStatement { source: object_name("foo"), target: "bar".to_string(), + partitioned_by: vec![], + has_header: false, + stored_as: Some("PARQUET".to_owned()), options: vec![], }); let expected = Statement::Explain(ExplainStatement { @@ -1375,10 +1468,13 @@ mod tests { panic!("Expected query, got {statement:?}"); }; - let sql = "COPY (SELECT 1) TO bar"; + let sql = "COPY (SELECT 1) TO bar STORED AS CSV WITH HEADER ROW"; let expected = Statement::CopyTo(CopyToStatement { source: CopyToSource::Query(query), target: "bar".to_string(), + partitioned_by: vec![], + has_header: true, + stored_as: Some("CSV".to_owned()), options: vec![], }); assert_eq!(verified_stmt(sql), expected); @@ -1387,10 +1483,31 @@ mod tests { #[test] fn copy_to_options() -> Result<(), ParserError> { - let sql = "COPY foo TO bar (row_group_size 55)"; + let sql = "COPY foo TO bar STORED AS CSV OPTIONS (row_group_size 55)"; + let expected = Statement::CopyTo(CopyToStatement { + source: object_name("foo"), + target: "bar".to_string(), + partitioned_by: vec![], + has_header: false, + stored_as: Some("CSV".to_owned()), + options: vec![( + "row_group_size".to_string(), + Value::Number("55".to_string(), false), + )], + }); + assert_eq!(verified_stmt(sql), expected); + Ok(()) + } + + #[test] + fn copy_to_partitioned_by() -> Result<(), ParserError> { + let sql = "COPY foo TO bar STORED AS CSV PARTITIONED BY (a) OPTIONS (row_group_size 55)"; let expected = Statement::CopyTo(CopyToStatement { source: object_name("foo"), target: "bar".to_string(), + partitioned_by: vec!["a".to_string()], + has_header: false, + stored_as: Some("CSV".to_owned()), options: vec![( "row_group_size".to_string(), Value::Number("55".to_string(), false), @@ -1404,24 +1521,24 @@ mod tests { fn copy_to_multi_options() -> Result<(), ParserError> { // order of options is preserved let sql = - "COPY foo TO bar (format parquet, row_group_size 55, compression snappy)"; + "COPY foo TO bar STORED AS parquet OPTIONS ('format.row_group_size' 55, 'format.compression' snappy)"; let expected_options = vec![ ( - "format".to_string(), - Value::UnQuotedString("parquet".to_string()), - ), - ( - "row_group_size".to_string(), + "format.row_group_size".to_string(), Value::Number("55".to_string(), false), ), ( - "compression".to_string(), + "format.compression".to_string(), Value::UnQuotedString("snappy".to_string()), ), ]; - let options = if let Statement::CopyTo(copy_to) = verified_stmt(sql) { + let mut statements = DFParser::parse_sql(sql).unwrap(); + assert_eq!(statements.len(), 1); + let only_statement = statements.pop_front().unwrap(); + + let options = if let Statement::CopyTo(copy_to) = only_statement { copy_to.options } else { panic!("Expected copy"); @@ -1460,7 +1577,10 @@ mod tests { } let only_statement = statements.pop_front().unwrap(); - assert_eq!(canonical, only_statement.to_string()); + assert_eq!( + canonical.to_uppercase(), + only_statement.to_string().to_uppercase() + ); only_statement } diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 412c3b753ed5..e50aceb757df 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -813,20 +813,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fn copy_to_plan(&self, statement: CopyToStatement) -> Result { // determine if source is table or query and handle accordingly let copy_source = statement.source; - let input = match copy_source { + let (input, input_schema, table_ref) = match copy_source { CopyToSource::Relation(object_name) => { - let table_ref = - self.object_name_to_table_reference(object_name.clone())?; - let table_source = self.context_provider.get_table_source(table_ref)?; - LogicalPlanBuilder::scan( - object_name_to_string(&object_name), - table_source, - None, - )? - .build()? + let table_name = object_name_to_string(&object_name); + let table_ref = self.object_name_to_table_reference(object_name)?; + let table_source = + self.context_provider.get_table_source(table_ref.clone())?; + let plan = + LogicalPlanBuilder::scan(table_name, table_source, None)?.build()?; + let input_schema = plan.schema().clone(); + (plan, input_schema, Some(table_ref)) } CopyToSource::Query(query) => { - self.query_to_plan(query, &mut PlannerContext::new())? + let plan = self.query_to_plan(query, &mut PlannerContext::new())?; + let input_schema = plan.schema().clone(); + (plan, input_schema, None) } }; @@ -852,8 +853,41 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { options.insert(key.to_lowercase(), value_string.to_lowercase()); } - let file_type = try_infer_file_type(&mut options, &statement.target)?; - let partition_by = take_partition_by(&mut options); + let file_type = if let Some(file_type) = statement.stored_as { + FileType::from_str(&file_type).map_err(|_| { + DataFusionError::Configuration(format!("Unknown FileType {}", file_type)) + })? + } else { + let e = || { + DataFusionError::Configuration( + "Format not explicitly set and unable to get file extension! Use STORED AS to define file format." + .to_string(), + ) + }; + // try to infer file format from file extension + let extension: &str = &Path::new(&statement.target) + .extension() + .ok_or_else(e)? + .to_str() + .ok_or_else(e)? + .to_lowercase(); + + FileType::from_str(extension).map_err(|e| { + DataFusionError::Configuration(format!( + "{}. Use STORED AS to define file format.", + e + )) + })? + }; + + let partition_by = statement + .partitioned_by + .iter() + .map(|col| input_schema.field_with_name(table_ref.as_ref(), col)) + .collect::>>()? + .into_iter() + .map(|f| f.name().to_owned()) + .collect(); Ok(LogicalPlan::Copy(CopyTo { input: Arc::new(input), @@ -1469,82 +1503,3 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .is_ok() } } - -/// Infers the file type for a given target based on provided options or file extension. -/// -/// This function tries to determine the file type based on the 'format' option present -/// in the provided options hashmap. If 'format' is not explicitly set, the function attempts -/// to infer the file type from the file extension of the target. It returns an error if neither -/// the format option is set nor the file extension can be determined or parsed. -/// -/// # Arguments -/// -/// * `options` - A mutable reference to a HashMap containing options where the file format -/// might be specified under the 'format' key. -/// * `target` - A string slice representing the path to the file for which the file type needs to be inferred. -/// -/// # Returns -/// -/// Returns `Result` which is Ok if the file type could be successfully inferred, -/// otherwise returns an error in case of failure to determine or parse the file format or extension. -/// -/// # Errors -/// -/// This function returns an error in two cases: -/// - If the 'format' option is not set and the file extension cannot be retrieved from `target`. -/// - If the file extension is found but cannot be converted into a valid string. -/// -pub fn try_infer_file_type( - options: &mut HashMap, - target: &str, -) -> Result { - let explicit_format = options.remove("format"); - let format = match explicit_format { - Some(s) => FileType::from_str(&s), - None => { - // try to infer file format from file extension - let extension: &str = &Path::new(target) - .extension() - .ok_or(DataFusionError::Configuration( - "Format not explicitly set and unable to get file extension!" - .to_string(), - ))? - .to_str() - .ok_or(DataFusionError::Configuration( - "Format not explicitly set and failed to parse file extension!" - .to_string(), - ))? - .to_lowercase(); - - FileType::from_str(extension) - } - }?; - - Ok(format) -} - -/// Extracts and parses the 'partition_by' option from a provided options hashmap. -/// -/// This function looks for a 'partition_by' key in the options hashmap. If found, -/// it splits the value by commas, trims each resulting string, and replaces double -/// single quotes with a single quote. It returns a vector of partition column names. -/// -/// # Arguments -/// -/// * `options` - A mutable reference to a HashMap containing options where 'partition_by' -/// might be specified. -/// -/// # Returns -/// -/// Returns a `Vec` containing partition column names. If the 'partition_by' option -/// is not present, returns an empty vector. -pub fn take_partition_by(options: &mut HashMap) -> Vec { - let partition_by = options.remove("partition_by"); - match partition_by { - Some(part_cols) => part_cols - .split(',') - .map(|s| s.trim().replace("''", "'")) - .collect::>(), - None => vec![], - } -} diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index b6077353e5dd..6d335f1f8fc9 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -22,25 +22,23 @@ use std::{sync::Arc, vec}; use arrow_schema::TimeUnit::Nanosecond; use arrow_schema::*; -use datafusion_sql::planner::PlannerContext; -use datafusion_sql::unparser::{expr_to_sql, plan_to_sql}; -use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; - +use datafusion_common::config::ConfigOptions; use datafusion_common::{ - config::ConfigOptions, DataFusionError, Result, ScalarValue, TableReference, + plan_err, DFSchema, DataFusionError, ParamValues, Result, ScalarValue, TableReference, }; -use datafusion_common::{plan_err, DFSchema, ParamValues}; use datafusion_expr::{ logical_plan::{LogicalPlan, Prepare}, AggregateUDF, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, TableSource, Volatility, WindowUDF, }; +use datafusion_sql::unparser::{expr_to_sql, plan_to_sql}; use datafusion_sql::{ parser::DFParser, - planner::{ContextProvider, ParserOptions, SqlToRel}, + planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel}, }; use rstest::rstest; +use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; use sqlparser::parser::Parser; #[test] @@ -389,7 +387,7 @@ fn plan_rollback_transaction_chained() { #[test] fn plan_copy_to() { - let sql = "COPY test_decimal to 'output.csv'"; + let sql = "COPY test_decimal to 'output.csv' STORED AS CSV"; let plan = r#" CopyTo: format=csv output_url=output.csv options: () TableScan: test_decimal @@ -410,6 +408,18 @@ Explain quick_test(sql, plan); } +#[test] +fn plan_explain_copy_to_format() { + let sql = "EXPLAIN COPY test_decimal to 'output.tbl' STORED AS CSV"; + let plan = r#" +Explain + CopyTo: format=csv output_url=output.tbl options: () + TableScan: test_decimal + "# + .trim(); + quick_test(sql, plan); +} + #[test] fn plan_copy_to_query() { let sql = "COPY (select * from test_decimal limit 10) to 'output.csv'"; diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index df23a993ebce..4d4f596d0c60 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -21,13 +21,13 @@ create table source_table(col1 integer, col2 varchar) as values (1, 'Foo'), (2, # Copy to directory as multiple files query IT -COPY source_table TO 'test_files/scratch/copy/table/' (format parquet, 'parquet.compression' 'zstd(10)'); +COPY source_table TO 'test_files/scratch/copy/table/' STORED AS parquet OPTIONS ('format.compression' 'zstd(10)'); ---- 2 # Copy to directory as partitioned files query IT -COPY source_table TO 'test_files/scratch/copy/partitioned_table1/' (format parquet, 'parquet.compression' 'zstd(10)', partition_by 'col2'); +COPY source_table TO 'test_files/scratch/copy/partitioned_table1/' STORED AS parquet PARTITIONED BY (col2) OPTIONS ('format.compression' 'zstd(10)'); ---- 2 @@ -54,8 +54,8 @@ select * from validate_partitioned_parquet_bar order by col1; # Copy to directory as partitioned files query ITT -COPY (values (1, 'a', 'x'), (2, 'b', 'y'), (3, 'c', 'z')) TO 'test_files/scratch/copy/partitioned_table2/' -(format parquet, partition_by 'column2, column3', 'parquet.compression' 'zstd(10)'); +COPY (values (1, 'a', 'x'), (2, 'b', 'y'), (3, 'c', 'z')) TO 'test_files/scratch/copy/partitioned_table2/' STORED AS parquet PARTITIONED BY (column2, column3) +OPTIONS ('format.compression' 'zstd(10)'); ---- 3 @@ -82,8 +82,8 @@ select * from validate_partitioned_parquet_a_x order by column1; # Copy to directory as partitioned files query TTT -COPY (values ('1', 'a', 'x'), ('2', 'b', 'y'), ('3', 'c', 'z')) TO 'test_files/scratch/copy/partitioned_table3/' -(format parquet, 'parquet.compression' 'zstd(10)', partition_by 'column1, column3'); +COPY (values ('1', 'a', 'x'), ('2', 'b', 'y'), ('3', 'c', 'z')) TO 'test_files/scratch/copy/partitioned_table3/' STORED AS parquet PARTITIONED BY (column1, column3) +OPTIONS ('format.compression' 'zstd(10)'); ---- 3 @@ -111,49 +111,52 @@ a statement ok create table test ("'test'" varchar, "'test2'" varchar, "'test3'" varchar); -query TTT -insert into test VALUES ('a', 'x', 'aa'), ('b','y', 'bb'), ('c', 'z', 'cc') ----- -3 - -query T -select "'test'" from test ----- -a -b -c - -# Note to place a single ' inside of a literal string escape by putting two '' -query TTT -copy test to 'test_files/scratch/copy/escape_quote' (format csv, partition_by '''test2'',''test3''') ----- -3 - -statement ok -CREATE EXTERNAL TABLE validate_partitioned_escape_quote STORED AS CSV -LOCATION 'test_files/scratch/copy/escape_quote/' PARTITIONED BY ("'test2'", "'test3'"); - +## Until the partition by parsing uses ColumnDef, this test is meaningless since it becomes an overfit. Even in +## CREATE EXTERNAL TABLE, there is a schema mismatch, this should be an issue. +# +#query TTT +#insert into test VALUES ('a', 'x', 'aa'), ('b','y', 'bb'), ('c', 'z', 'cc') +#---- +#3 +# +#query T +#select "'test'" from test +#---- +#a +#b +#c +# +# # Note to place a single ' inside of a literal string escape by putting two '' +#query TTT +#copy test to 'test_files/scratch/copy/escape_quote' STORED AS CSV; +#---- +#3 +# +#statement ok +#CREATE EXTERNAL TABLE validate_partitioned_escape_quote STORED AS CSV +#LOCATION 'test_files/scratch/copy/escape_quote/' PARTITIONED BY ("'test2'", "'test3'"); +# # This triggers a panic (index out of bounds) # https://github.com/apache/arrow-datafusion/issues/9269 #query #select * from validate_partitioned_escape_quote; query TT -EXPLAIN COPY source_table TO 'test_files/scratch/copy/table/' (format parquet, 'parquet.compression' 'zstd(10)'); +EXPLAIN COPY source_table TO 'test_files/scratch/copy/table/' STORED AS PARQUET OPTIONS ('format.compression' 'zstd(10)'); ---- logical_plan -CopyTo: format=parquet output_url=test_files/scratch/copy/table/ options: (parquet.compression zstd(10)) +CopyTo: format=parquet output_url=test_files/scratch/copy/table/ options: (format.compression zstd(10)) --TableScan: source_table projection=[col1, col2] physical_plan FileSinkExec: sink=ParquetSink(file_groups=[]) --MemoryExec: partitions=1, partition_sizes=[1] # Error case -query error DataFusion error: Invalid or Unsupported Configuration: Format not explicitly set and unable to get file extension! +query error DataFusion error: Invalid or Unsupported Configuration: Format not explicitly set and unable to get file extension! Use STORED AS to define file format. EXPLAIN COPY source_table to 'test_files/scratch/copy/table/' query TT -EXPLAIN COPY source_table to 'test_files/scratch/copy/table/' (format parquet) +EXPLAIN COPY source_table to 'test_files/scratch/copy/table/' STORED AS PARQUET ---- logical_plan CopyTo: format=parquet output_url=test_files/scratch/copy/table/ options: () @@ -164,7 +167,7 @@ FileSinkExec: sink=ParquetSink(file_groups=[]) # Copy more files to directory via query query IT -COPY (select * from source_table UNION ALL select * from source_table) to 'test_files/scratch/copy/table/' (format parquet); +COPY (select * from source_table UNION ALL select * from source_table) to 'test_files/scratch/copy/table/' STORED AS PARQUET; ---- 4 @@ -185,7 +188,7 @@ select * from validate_parquet; query ? copy (values (struct(timestamp '2021-01-01 01:00:01', 1)), (struct(timestamp '2022-01-01 01:00:01', 2)), (struct(timestamp '2023-01-03 01:00:01', 3)), (struct(timestamp '2024-01-01 01:00:01', 4))) -to 'test_files/scratch/copy/table_nested2/' (format parquet); +to 'test_files/scratch/copy/table_nested2/' STORED AS PARQUET; ---- 4 @@ -204,7 +207,7 @@ query ?? COPY (values (struct ('foo', (struct ('foo', make_array(struct('a',1), struct('b',2))))), make_array(timestamp '2023-01-01 01:00:01',timestamp '2023-01-01 01:00:01')), (struct('bar', (struct ('foo', make_array(struct('aa',10), struct('bb',20))))), make_array(timestamp '2024-01-01 01:00:01', timestamp '2024-01-01 01:00:01'))) -to 'test_files/scratch/copy/table_nested/' (format parquet); +to 'test_files/scratch/copy/table_nested/' STORED AS PARQUET; ---- 2 @@ -221,7 +224,7 @@ select * from validate_parquet_nested; query ? copy (values ([struct('foo', 1), struct('bar', 2)])) to 'test_files/scratch/copy/array_of_struct/' -(format parquet); +STORED AS PARQUET; ---- 1 @@ -236,8 +239,7 @@ select * from validate_array_of_struct; query ? copy (values (struct('foo', [1,2,3], struct('bar', [2,3,4])))) -to 'test_files/scratch/copy/struct_with_array/' -(format parquet); +to 'test_files/scratch/copy/struct_with_array/' STORED AS PARQUET; ---- 1 @@ -255,31 +257,32 @@ select * from validate_struct_with_array; query IT COPY source_table TO 'test_files/scratch/copy/table_with_options/' -(format parquet, -'parquet.compression' snappy, -'parquet.compression::col1' 'zstd(5)', -'parquet.compression::col2' snappy, -'parquet.max_row_group_size' 12345, -'parquet.data_pagesize_limit' 1234, -'parquet.write_batch_size' 1234, -'parquet.writer_version' 2.0, -'parquet.dictionary_page_size_limit' 123, -'parquet.created_by' 'DF copy.slt', -'parquet.column_index_truncate_length' 123, -'parquet.data_page_row_count_limit' 1234, -'parquet.bloom_filter_enabled' true, -'parquet.bloom_filter_enabled::col1' false, -'parquet.bloom_filter_fpp::col2' 0.456, -'parquet.bloom_filter_ndv::col2' 456, -'parquet.encoding' plain, -'parquet.encoding::col1' DELTA_BINARY_PACKED, -'parquet.dictionary_enabled::col2' true, -'parquet.dictionary_enabled' false, -'parquet.statistics_enabled' page, -'parquet.statistics_enabled::col2' none, -'parquet.max_statistics_size' 123, -'parquet.bloom_filter_fpp' 0.001, -'parquet.bloom_filter_ndv' 100 +STORED AS PARQUET +OPTIONS ( +'format.compression' snappy, +'format.compression::col1' 'zstd(5)', +'format.compression::col2' snappy, +'format.max_row_group_size' 12345, +'format.data_pagesize_limit' 1234, +'format.write_batch_size' 1234, +'format.writer_version' 2.0, +'format.dictionary_page_size_limit' 123, +'format.created_by' 'DF copy.slt', +'format.column_index_truncate_length' 123, +'format.data_page_row_count_limit' 1234, +'format.bloom_filter_enabled' true, +'format.bloom_filter_enabled::col1' false, +'format.bloom_filter_fpp::col2' 0.456, +'format.bloom_filter_ndv::col2' 456, +'format.encoding' plain, +'format.encoding::col1' DELTA_BINARY_PACKED, +'format.dictionary_enabled::col2' true, +'format.dictionary_enabled' false, +'format.statistics_enabled' page, +'format.statistics_enabled::col2' none, +'format.max_statistics_size' 123, +'format.bloom_filter_fpp' 0.001, +'format.bloom_filter_ndv' 100 ) ---- 2 @@ -312,7 +315,7 @@ select * from validate_parquet_single; # copy from table to folder of compressed json files query IT -COPY source_table to 'test_files/scratch/copy/table_json_gz' (format json, 'json.compression' gzip); +COPY source_table to 'test_files/scratch/copy/table_json_gz' STORED AS JSON OPTIONS ('format.compression' gzip); ---- 2 @@ -328,7 +331,7 @@ select * from validate_json_gz; # copy from table to folder of compressed csv files query IT -COPY source_table to 'test_files/scratch/copy/table_csv' (format csv, 'csv.has_header' false, 'csv.compression' gzip); +COPY source_table to 'test_files/scratch/copy/table_csv' STORED AS CSV OPTIONS ('format.has_header' false, 'format.compression' gzip); ---- 2 @@ -360,7 +363,7 @@ select * from validate_single_csv; # Copy from table to folder of json query IT -COPY source_table to 'test_files/scratch/copy/table_json' (format json); +COPY source_table to 'test_files/scratch/copy/table_json' STORED AS JSON; ---- 2 @@ -376,7 +379,7 @@ select * from validate_json; # Copy from table to single json file query IT -COPY source_table to 'test_files/scratch/copy/table.json'; +COPY source_table to 'test_files/scratch/copy/table.json' STORED AS JSON ; ---- 2 @@ -394,12 +397,12 @@ select * from validate_single_json; query IT COPY source_table to 'test_files/scratch/copy/table_csv_with_options' -(format csv, -'csv.has_header' false, -'csv.compression' uncompressed, -'csv.datetime_format' '%FT%H:%M:%S.%9f', -'csv.delimiter' ';', -'csv.null_value' 'NULLVAL'); +STORED AS CSV OPTIONS ( +'format.has_header' false, +'format.compression' uncompressed, +'format.datetime_format' '%FT%H:%M:%S.%9f', +'format.delimiter' ';', +'format.null_value' 'NULLVAL'); ---- 2 @@ -417,7 +420,7 @@ select * from validate_csv_with_options; # Copy from table to single arrow file query IT -COPY source_table to 'test_files/scratch/copy/table.arrow'; +COPY source_table to 'test_files/scratch/copy/table.arrow' STORED AS ARROW; ---- 2 @@ -437,7 +440,7 @@ select * from validate_arrow_file; query T? COPY (values ('c', arrow_cast('foo', 'Dictionary(Int32, Utf8)')), ('d', arrow_cast('bar', 'Dictionary(Int32, Utf8)'))) -to 'test_files/scratch/copy/table_dict.arrow'; +to 'test_files/scratch/copy/table_dict.arrow' STORED AS ARROW; ---- 2 @@ -456,7 +459,7 @@ d bar # Copy from table to folder of json query IT -COPY source_table to 'test_files/scratch/copy/table_arrow' (format arrow); +COPY source_table to 'test_files/scratch/copy/table_arrow' STORED AS ARROW; ---- 2 @@ -475,12 +478,12 @@ select * from validate_arrow; # Copy from table with options query error DataFusion error: Invalid or Unsupported Configuration: Config value "row_group_size" not found on JsonOptions -COPY source_table to 'test_files/scratch/copy/table.json' ('json.row_group_size' 55); +COPY source_table to 'test_files/scratch/copy/table.json' STORED AS JSON OPTIONS ('format.row_group_size' 55); # Incomplete statement query error DataFusion error: SQL error: ParserError\("Expected \), found: EOF"\) COPY (select col2, sum(col1) from source_table # Copy from table with non literal -query error DataFusion error: SQL error: ParserError\("Expected ',' or '\)' after option definition, found: \+"\) +query error DataFusion error: SQL error: ParserError\("Unexpected token \("\) COPY source_table to '/tmp/table.parquet' (row_group_size 55 + 102); diff --git a/datafusion/sqllogictest/test_files/create_external_table.slt b/datafusion/sqllogictest/test_files/create_external_table.slt index 3b85dd9e986f..c4a26a5e227d 100644 --- a/datafusion/sqllogictest/test_files/create_external_table.slt +++ b/datafusion/sqllogictest/test_files/create_external_table.slt @@ -101,8 +101,8 @@ statement error DataFusion error: SQL error: ParserError\("Unexpected token FOOB CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV FOOBAR BARBAR BARFOO LOCATION 'foo.csv'; # Conflicting options -statement error DataFusion error: Invalid or Unsupported Configuration: Key "parquet.column_index_truncate_length" is not applicable for CSV format +statement error DataFusion error: Invalid or Unsupported Configuration: Config value "column_index_truncate_length" not found on CsvOptions CREATE EXTERNAL TABLE csv_table (column1 int) STORED AS CSV LOCATION 'foo.csv' -OPTIONS ('csv.delimiter' ';', 'parquet.column_index_truncate_length' '123') +OPTIONS ('format.delimiter' ';', 'format.column_index_truncate_length' '123') diff --git a/datafusion/sqllogictest/test_files/csv_files.slt b/datafusion/sqllogictest/test_files/csv_files.slt index 7b299c0cf143..ab6847afb6a5 100644 --- a/datafusion/sqllogictest/test_files/csv_files.slt +++ b/datafusion/sqllogictest/test_files/csv_files.slt @@ -23,7 +23,7 @@ c2 VARCHAR ) STORED AS CSV WITH HEADER ROW DELIMITER ',' -OPTIONS ('csv.quote' '~') +OPTIONS ('format.quote' '~') LOCATION '../core/tests/data/quote.csv'; statement ok @@ -33,7 +33,7 @@ c2 VARCHAR ) STORED AS CSV WITH HEADER ROW DELIMITER ',' -OPTIONS ('csv.escape' '\') +OPTIONS ('format.escape' '\') LOCATION '../core/tests/data/escape.csv'; query TT @@ -71,7 +71,7 @@ c2 VARCHAR ) STORED AS CSV WITH HEADER ROW DELIMITER ',' -OPTIONS ('csv.escape' '"') +OPTIONS ('format.escape' '"') LOCATION '../core/tests/data/escape.csv'; # TODO: Validate this with better data. @@ -117,14 +117,14 @@ CREATE TABLE src_table_2 ( query ITII COPY src_table_1 TO 'test_files/scratch/csv_files/csv_partitions/1.csv' -(FORMAT CSV); +STORED AS CSV; ---- 4 query ITII COPY src_table_2 TO 'test_files/scratch/csv_files/csv_partitions/2.csv' -(FORMAT CSV); +STORED AS CSV; ---- 4 diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 3d9f8ff3ad2c..869462b4722a 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -4506,28 +4506,28 @@ CREATE TABLE src_table ( query PI COPY (SELECT * FROM src_table) TO 'test_files/scratch/group_by/timestamp_table/0.csv' -(FORMAT CSV); +STORED AS CSV; ---- 10 query PI COPY (SELECT * FROM src_table) TO 'test_files/scratch/group_by/timestamp_table/1.csv' -(FORMAT CSV); +STORED AS CSV; ---- 10 query PI COPY (SELECT * FROM src_table) TO 'test_files/scratch/group_by/timestamp_table/2.csv' -(FORMAT CSV); +STORED AS CSV; ---- 10 query PI COPY (SELECT * FROM src_table) TO 'test_files/scratch/group_by/timestamp_table/3.csv' -(FORMAT CSV); +STORED AS CSV; ---- 10 diff --git a/datafusion/sqllogictest/test_files/parquet.slt b/datafusion/sqllogictest/test_files/parquet.slt index b7cd1243cb0f..3cc52666d533 100644 --- a/datafusion/sqllogictest/test_files/parquet.slt +++ b/datafusion/sqllogictest/test_files/parquet.slt @@ -45,7 +45,7 @@ CREATE TABLE src_table ( query ITID COPY (SELECT * FROM src_table LIMIT 3) TO 'test_files/scratch/parquet/test_table/0.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; ---- 3 @@ -53,7 +53,7 @@ TO 'test_files/scratch/parquet/test_table/0.parquet' query ITID COPY (SELECT * FROM src_table WHERE int_col > 3 LIMIT 3) TO 'test_files/scratch/parquet/test_table/1.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; ---- 3 @@ -128,7 +128,7 @@ SortPreservingMergeExec: [string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] query ITID COPY (SELECT * FROM src_table WHERE int_col > 6 LIMIT 3) TO 'test_files/scratch/parquet/test_table/2.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; ---- 3 @@ -281,7 +281,7 @@ LIMIT 10; query ITID COPY (SELECT * FROM src_table WHERE int_col > 6 LIMIT 3) TO 'test_files/scratch/parquet/test_table/subdir/3.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; ---- 3 diff --git a/datafusion/sqllogictest/test_files/repartition.slt b/datafusion/sqllogictest/test_files/repartition.slt index 391a6739b060..594c52f12d75 100644 --- a/datafusion/sqllogictest/test_files/repartition.slt +++ b/datafusion/sqllogictest/test_files/repartition.slt @@ -25,7 +25,7 @@ set datafusion.execution.target_partitions = 4; statement ok COPY (VALUES (1, 2), (2, 5), (3, 2), (4, 5), (5, 0)) TO 'test_files/scratch/repartition/parquet_table/2.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; statement ok CREATE EXTERNAL TABLE parquet_table(column1 int, column2 int) diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt index 15fe670a454c..fe0f6c1e8139 100644 --- a/datafusion/sqllogictest/test_files/repartition_scan.slt +++ b/datafusion/sqllogictest/test_files/repartition_scan.slt @@ -35,7 +35,7 @@ set datafusion.optimizer.repartition_file_min_size = 1; # Note filename 2.parquet to test sorting (on local file systems it is often listed before 1.parquet) statement ok COPY (VALUES (1), (2), (3), (4), (5)) TO 'test_files/scratch/repartition_scan/parquet_table/2.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; statement ok CREATE EXTERNAL TABLE parquet_table(column1 int) @@ -86,7 +86,7 @@ set datafusion.optimizer.enable_round_robin_repartition = true; # create a second parquet file statement ok COPY (VALUES (100), (200)) TO 'test_files/scratch/repartition_scan/parquet_table/1.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; ## Still expect to see the scan read the file as "4" groups with even sizes. One group should read ## parts of both files. @@ -158,7 +158,7 @@ DROP TABLE parquet_table_with_order; # create a single csv file statement ok COPY (VALUES (1), (2), (3), (4), (5)) TO 'test_files/scratch/repartition_scan/csv_table/1.csv' -(FORMAT csv, 'csv.has_header' true); +STORED AS CSV WITH HEADER ROW; statement ok CREATE EXTERNAL TABLE csv_table(column1 int) @@ -202,7 +202,7 @@ DROP TABLE csv_table; # create a single json file statement ok COPY (VALUES (1), (2), (3), (4), (5)) TO 'test_files/scratch/repartition_scan/json_table/1.json' -(FORMAT json); +STORED AS JSON; statement ok CREATE EXTERNAL TABLE json_table (column1 int) diff --git a/datafusion/sqllogictest/test_files/schema_evolution.slt b/datafusion/sqllogictest/test_files/schema_evolution.slt index aee0e97edc1e..5572c4a5ffef 100644 --- a/datafusion/sqllogictest/test_files/schema_evolution.slt +++ b/datafusion/sqllogictest/test_files/schema_evolution.slt @@ -31,7 +31,7 @@ COPY ( SELECT column1 as a, column2 as b FROM ( VALUES ('foo', 1), ('foo', 2), ('foo', 3) ) ) TO 'test_files/scratch/schema_evolution/parquet_table/1.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; # File2 has only b @@ -40,7 +40,7 @@ COPY ( SELECT column1 as b FROM ( VALUES (10) ) ) TO 'test_files/scratch/schema_evolution/parquet_table/2.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; # File3 has a column from 'z' which does not appear in the table # but also values from a which do appear in the table @@ -49,7 +49,7 @@ COPY ( SELECT column1 as z, column2 as a FROM ( VALUES ('bar', 'foo'), ('blarg', 'foo') ) ) TO 'test_files/scratch/schema_evolution/parquet_table/3.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; # File4 has data for b and a (reversed) and d statement ok @@ -57,7 +57,7 @@ COPY ( SELECT column1 as b, column2 as a, column3 as c FROM ( VALUES (100, 'foo', 10.5), (200, 'foo', 12.6), (300, 'bzz', 13.7) ) ) TO 'test_files/scratch/schema_evolution/parquet_table/4.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; # The logical distribution of `a`, `b` and `c` in the files is like this: # diff --git a/docs/source/user-guide/sql/dml.md b/docs/source/user-guide/sql/dml.md index 405e77a21b26..b9614bb8f929 100644 --- a/docs/source/user-guide/sql/dml.md +++ b/docs/source/user-guide/sql/dml.md @@ -49,7 +49,7 @@ Copy the contents of `source_table` to one or more Parquet formatted files in the `dir_name` directory: ```sql -> COPY source_table TO 'dir_name' (FORMAT parquet); +> COPY source_table TO 'dir_name' STORED AS PARQUET; +-------+ | count | +-------+ From fa7ca27c15328247dbf98b2f8773c19398b8a745 Mon Sep 17 00:00:00 2001 From: Chunchun Ye <14298407+appletreeisyellow@users.noreply.github.com> Date: Mon, 18 Mar 2024 17:53:46 -0500 Subject: [PATCH 10/35] Support "A column is known to be entirely NULL" in `PruningPredicate` (#9223) * feat: add row_counts() to PruningStatistics trait * chore: remove comments * feat(pruning): add predicate rewrite for `CASE WHEN x_null_count = x_row_count THEN false ELSE ... END` * chore: clippy and update pruning predicates in tests * chore(pruning): fix data type and column expression for null and row counts chore: fix pruning_predicate in slt tests chore: clippy * doc: add examples in doc * chore: update comments * docs: use feedback Co-authored-by: Andrew Lamb docs: take more feedback * test: add test * docs: update comments * docs: update comments to put rewritten predicate first --- datafusion-examples/examples/pruning.rs | 5 + .../datasource/physical_plan/parquet/mod.rs | 2 +- .../physical_plan/parquet/page_filter.rs | 4 + .../physical_plan/parquet/row_groups.rs | 8 + .../core/src/physical_optimizer/pruning.rs | 516 ++++++++++++++++-- .../test_files/repartition_scan.slt | 8 +- 6 files changed, 492 insertions(+), 51 deletions(-) diff --git a/datafusion-examples/examples/pruning.rs b/datafusion-examples/examples/pruning.rs index 1d84fc2d1e0a..3fa35049a8da 100644 --- a/datafusion-examples/examples/pruning.rs +++ b/datafusion-examples/examples/pruning.rs @@ -163,6 +163,11 @@ impl PruningStatistics for MyCatalog { None } + fn row_counts(&self, _column: &Column) -> Option { + // In this example, we know nothing about the number of rows in each file + None + } + fn contained( &self, _column: &Column, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 2cfbb578da66..a2e645cf3e72 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -1870,7 +1870,7 @@ mod tests { assert_contains!( &display, - "pruning_predicate=c1_min@0 != bar OR bar != c1_max@1" + "pruning_predicate=CASE WHEN c1_null_count@2 = c1_row_count@3 THEN false ELSE c1_min@0 != bar OR bar != c1_max@1 END" ); assert_contains!(&display, r#"predicate=c1@0 != bar"#); diff --git a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs index 064a8e1fff33..c7706f3458d0 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs @@ -547,6 +547,10 @@ impl<'a> PruningStatistics for PagesPruningStatistics<'a> { } } + fn row_counts(&self, _column: &datafusion_common::Column) -> Option { + None + } + fn contained( &self, _column: &datafusion_common::Column, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 1a84f52a33fd..a0bb5ab71204 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -199,6 +199,10 @@ impl PruningStatistics for BloomFilterStatistics { None } + fn row_counts(&self, _column: &Column) -> Option { + None + } + /// Use bloom filters to determine if we are sure this column can not /// possibly contain `values` /// @@ -332,6 +336,10 @@ impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { scalar.to_array().ok() } + fn row_counts(&self, _column: &Column) -> Option { + None + } + fn contained( &self, _column: &Column, diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index d2126f90eca9..80bb5ad42e81 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -53,7 +53,7 @@ use log::trace; /// /// 1. Minimum and maximum values for columns /// -/// 2. Null counts for columns +/// 2. Null counts and row counts for columns /// /// 3. Whether the values in a column are contained in a set of literals /// @@ -100,7 +100,8 @@ pub trait PruningStatistics { /// these statistics. /// /// This value corresponds to the size of the [`ArrayRef`] returned by - /// [`Self::min_values`], [`Self::max_values`], and [`Self::null_counts`]. + /// [`Self::min_values`], [`Self::max_values`], [`Self::null_counts`], + /// and [`Self::row_counts`]. fn num_containers(&self) -> usize; /// Return the number of null values for the named column as an @@ -111,6 +112,14 @@ pub trait PruningStatistics { /// Note: the returned array must contain [`Self::num_containers`] rows fn null_counts(&self, column: &Column) -> Option; + /// Return the number of rows for the named column in each container + /// as an `Option`. + /// + /// See [`Self::min_values`] for when to return `None` and null values. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows + fn row_counts(&self, column: &Column) -> Option; + /// Returns [`BooleanArray`] where each row represents information known /// about specific literal `values` in a column. /// @@ -268,7 +277,7 @@ pub trait PruningStatistics { /// 3. [`PruningStatistics`] that provides information about columns in that /// schema, for multiple “containers”. For each column in each container, it /// provides optional information on contained values, min_values, max_values, -/// and null_counts counts. +/// null_counts counts, and row_counts counts. /// /// **Outputs**: /// A (non null) boolean value for each container: @@ -306,17 +315,23 @@ pub trait PruningStatistics { /// * `false`: there are no rows that could possibly match the predicate, /// **PRUNES** the container /// -/// For example, given a column `x`, the `x_min` and `x_max` and `x_null_count` -/// represent the minimum and maximum values, and the null count of column `x`, -/// provided by the `PruningStatistics`. Here are some examples of the rewritten -/// predicates: +/// For example, given a column `x`, the `x_min`, `x_max`, `x_null_count`, and +/// `x_row_count` represent the minimum and maximum values, the null count of +/// column `x`, and the row count of column `x`, provided by the `PruningStatistics`. +/// `x_null_count` and `x_row_count` are used to handle the case where the column `x` +/// is known to be all `NULL`s. Note this is different from knowing nothing about +/// the column `x`, which confusingly is encoded by returning `NULL` for the min/max +/// values from [`PruningStatistics::max_values`] and [`PruningStatistics::min_values`]. +/// +/// Here are some examples of the rewritten predicates: /// /// Original Predicate | Rewritten Predicate /// ------------------ | -------------------- -/// `x = 5` | `x_min <= 5 AND 5 <= x_max` -/// `x < 5` | `x_max < 5` -/// `x = 5 AND y = 10` | `x_min <= 5 AND 5 <= x_max AND y_min <= 10 AND 10 <= y_max` -/// `x IS NULL` | `x_null_count > 0` +/// `x = 5` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_min <= 5 AND 5 <= x_max END` +/// `x < 5` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_max < 5 END` +/// `x = 5 AND y = 10` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_min <= 5 AND 5 <= x_max END AND CASE WHEN y_null_count = y_row_count THEN false ELSE y_min <= 10 AND 10 <= y_max END` +/// `x IS NULL` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_null_count > 0 END` +/// `CAST(x as int) = 5` | `CASE WHEN x_null_count = x_row_count THEN false ELSE CAST(x_min as int) <= 5 AND 5 <= CAST(x_max as int) END` /// /// ## Predicate Evaluation /// The PruningPredicate works in two passes @@ -326,28 +341,47 @@ pub trait PruningStatistics { /// LiteralGuarantees are not satisfied /// /// **Second Pass**: Evaluates the rewritten expression using the -/// min/max/null_counts values for each column for each container. For any +/// min/max/null_counts/row_counts values for each column for each container. For any /// container that this expression evaluates to `false`, it rules out those /// containers. /// -/// For example, given the predicate, `x = 5 AND y = 10`, if we know `x` is -/// between `1 and 100` and we know that `y` is between `4` and `7`, the input -/// statistics might look like +/// +/// ### Example 1 +/// +/// Given the predicate, `x = 5 AND y = 10`, the rewritten predicate would look like: +/// +/// ```sql +/// CASE +/// WHEN x_null_count = x_row_count THEN false +/// ELSE x_min <= 5 AND 5 <= x_max +/// END +/// AND +/// CASE +/// WHEN y_null_count = y_row_count THEN false +/// ELSE y_min <= 10 AND 10 <= y_max +/// END +/// ``` +/// +/// If we know that for a given container, `x` is between `1 and 100` and we know that +/// `y` is between `4` and `7`, we know nothing about the null count and row count of +/// `x` and `y`, the input statistics might look like: /// /// Column | Value /// -------- | ----- /// `x_min` | `1` /// `x_max` | `100` +/// `x_null_count` | `null` +/// `x_row_count` | `null` /// `y_min` | `4` /// `y_max` | `7` +/// `y_null_count` | `null` +/// `y_row_count` | `null` /// -/// The rewritten predicate would look like -/// -/// `x_min <= 5 AND 5 <= x_max AND y_min <= 10 AND 10 <= y_max` -/// -/// When these values are substituted in to the rewritten predicate and +/// When these statistics values are substituted in to the rewritten predicate and /// simplified, the result is `false`: /// +/// * `CASE WHEN null = null THEN false ELSE 1 <= 5 AND 5 <= 100 END AND CASE WHEN null = null THEN false ELSE 4 <= 10 AND 10 <= 7 END` +/// * `null = null` is `null` which is not true, so the `CASE` expression will use the `ELSE` clause /// * `1 <= 5 AND 5 <= 100 AND 4 <= 10 AND 10 <= 7` /// * `true AND true AND true AND false` /// * `false` @@ -364,6 +398,52 @@ pub trait PruningStatistics { /// more analysis, for example by actually reading the data and evaluating the /// predicate row by row. /// +/// ### Example 2 +/// +/// Given the same predicate, `x = 5 AND y = 10`, the rewritten predicate would +/// look like the same as example 1: +/// +/// ```sql +/// CASE +/// WHEN x_null_count = x_row_count THEN false +/// ELSE x_min <= 5 AND 5 <= x_max +/// END +/// AND +/// CASE +/// WHEN y_null_count = y_row_count THEN false +/// ELSE y_min <= 10 AND 10 <= y_max +/// END +/// ``` +/// +/// If we know that for another given container, `x_min` is NULL and `x_max` is +/// NULL (the min/max values are unknown), `x_null_count` is `100` and `x_row_count` +/// is `100`; we know that `y` is between `4` and `7`, but we know nothing about +/// the null count and row count of `y`. The input statistics might look like: +/// +/// Column | Value +/// -------- | ----- +/// `x_min` | `null` +/// `x_max` | `null` +/// `x_null_count` | `100` +/// `x_row_count` | `100` +/// `y_min` | `4` +/// `y_max` | `7` +/// `y_null_count` | `null` +/// `y_row_count` | `null` +/// +/// When these statistics values are substituted in to the rewritten predicate and +/// simplified, the result is `false`: +/// +/// * `CASE WHEN 100 = 100 THEN false ELSE null <= 5 AND 5 <= null END AND CASE WHEN null = null THEN false ELSE 4 <= 10 AND 10 <= 7 END` +/// * Since `100 = 100` is `true`, the `CASE` expression will use the `THEN` clause, i.e. `false` +/// * The other `CASE` expression will use the `ELSE` clause, i.e. `4 <= 10 AND 10 <= 7` +/// * `false AND true` +/// * `false` +/// +/// Returning `false` means the container can be pruned, which matches the +/// intuition that `x = 5 AND y = 10` can’t be true for all values in `x` +/// are known to be NULL. +/// /// # Related Work /// /// [`PruningPredicate`] implements the type of min/max pruning described in @@ -744,6 +824,22 @@ impl RequiredColumns { "null_count", ) } + + /// rewrite col --> col_row_count + fn row_count_column_expr( + &mut self, + column: &phys_expr::Column, + column_expr: &Arc, + field: &Field, + ) -> Result> { + self.stat_column_expr( + column, + column_expr, + field, + StatisticsType::RowCount, + "row_count", + ) + } } impl From> for RequiredColumns { @@ -794,6 +890,7 @@ fn build_statistics_record_batch( StatisticsType::Min => statistics.min_values(&column), StatisticsType::Max => statistics.max_values(&column), StatisticsType::NullCount => statistics.null_counts(&column), + StatisticsType::RowCount => statistics.row_counts(&column), }; let array = array.unwrap_or_else(|| new_null_array(data_type, num_containers)); @@ -903,6 +1000,46 @@ impl<'a> PruningExpressionBuilder<'a> { self.required_columns .max_column_expr(&self.column, &self.column_expr, self.field) } + + /// This function is to simply retune the `null_count` physical expression no matter what the + /// predicate expression is + /// + /// i.e., x > 5 => x_null_count, + /// cast(x as int) < 10 => x_null_count, + /// try_cast(x as float) < 10.0 => x_null_count + fn null_count_column_expr(&mut self) -> Result> { + // Retune to [`phys_expr::Column`] + let column_expr = Arc::new(self.column.clone()) as _; + + // null_count is DataType::UInt64, which is different from the column's data type (i.e. self.field) + let null_count_field = &Field::new(self.field.name(), DataType::UInt64, true); + + self.required_columns.null_count_column_expr( + &self.column, + &column_expr, + null_count_field, + ) + } + + /// This function is to simply retune the `row_count` physical expression no matter what the + /// predicate expression is + /// + /// i.e., x > 5 => x_row_count, + /// cast(x as int) < 10 => x_row_count, + /// try_cast(x as float) < 10.0 => x_row_count + fn row_count_column_expr(&mut self) -> Result> { + // Retune to [`phys_expr::Column`] + let column_expr = Arc::new(self.column.clone()) as _; + + // row_count is DataType::UInt64, which is different from the column's data type (i.e. self.field) + let row_count_field = &Field::new(self.field.name(), DataType::UInt64, true); + + self.required_columns.row_count_column_expr( + &self.column, + &column_expr, + row_count_field, + ) + } } /// This function is designed to rewrite the column_expr to @@ -1320,14 +1457,56 @@ fn build_statistics_expr( ); } }; + let statistics_expr = wrap_case_expr(statistics_expr, expr_builder)?; Ok(statistics_expr) } +/// Wrap the statistics expression in a case expression. +/// This is necessary to handle the case where the column is known +/// to be all nulls. +/// +/// For example: +/// +/// `x_min <= 10 AND 10 <= x_max` +/// +/// will become +/// +/// ```sql +/// CASE +/// WHEN x_null_count = x_row_count THEN false +/// ELSE x_min <= 10 AND 10 <= x_max +/// END +/// ```` +/// +/// If the column is known to be all nulls, then the expression +/// `x_null_count = x_row_count` will be true, which will cause the +/// case expression to return false. Therefore, prune out the container. +fn wrap_case_expr( + statistics_expr: Arc, + expr_builder: &mut PruningExpressionBuilder, +) -> Result> { + // x_null_count = x_row_count + let when_null_count_eq_row_count = Arc::new(phys_expr::BinaryExpr::new( + expr_builder.null_count_column_expr()?, + Operator::Eq, + expr_builder.row_count_column_expr()?, + )); + let then = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(false)))); + + // CASE WHEN x_null_count = x_row_count THEN false ELSE END + Ok(Arc::new(phys_expr::CaseExpr::try_new( + None, + vec![(when_null_count_eq_row_count, then)], + Some(statistics_expr), + )?)) +} + #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub(crate) enum StatisticsType { Min, Max, NullCount, + RowCount, } #[cfg(test)] @@ -1361,6 +1540,7 @@ mod tests { max: Option, /// Optional values null_counts: Option, + row_counts: Option, /// Optional known values (e.g. mimic a bloom filter) /// (value, contained) /// If present, all BooleanArrays must be the same size as min/max @@ -1440,6 +1620,10 @@ mod tests { self.null_counts.clone() } + fn row_counts(&self) -> Option { + self.row_counts.clone() + } + /// return an iterator over all arrays in this statistics fn arrays(&self) -> Vec { let contained_arrays = self @@ -1451,6 +1635,7 @@ mod tests { self.min.as_ref().cloned(), self.max.as_ref().cloned(), self.null_counts.as_ref().cloned(), + self.row_counts.as_ref().cloned(), ] .into_iter() .flatten() @@ -1509,6 +1694,20 @@ mod tests { self } + /// Add row counts. There must be the same number of row counts as + /// there are containers + fn with_row_counts( + mut self, + counts: impl IntoIterator>, + ) -> Self { + let row_counts: ArrayRef = + Arc::new(counts.into_iter().collect::()); + + self.assert_invariants(); + self.row_counts = Some(row_counts); + self + } + /// Add contained information. pub fn with_contained( mut self, @@ -1576,6 +1775,28 @@ mod tests { self } + /// Add row counts for the specified columm. + /// There must be the same number of row counts as + /// there are containers + fn with_row_counts( + mut self, + name: impl Into, + counts: impl IntoIterator>, + ) -> Self { + let col = Column::from_name(name.into()); + + // take stats out and update them + let container_stats = self + .stats + .remove(&col) + .unwrap_or_default() + .with_row_counts(counts); + + // put stats back in + self.stats.insert(col, container_stats); + self + } + /// Add contained information for the specified columm. fn with_contained( mut self, @@ -1628,6 +1849,13 @@ mod tests { .unwrap_or(None) } + fn row_counts(&self, column: &Column) -> Option { + self.stats + .get(column) + .map(|container_stats| container_stats.row_counts()) + .unwrap_or(None) + } + fn contained( &self, column: &Column, @@ -1663,6 +1891,10 @@ mod tests { None } + fn row_counts(&self, _column: &Column) -> Option { + None + } + fn contained( &self, _column: &Column, @@ -1853,7 +2085,7 @@ mod tests { #[test] fn row_group_predicate_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_min@0 <= 1 AND 1 <= c1_max@1"; + let expected_expr = "CASE WHEN c1_null_count@2 = c1_row_count@3 THEN false ELSE c1_min@0 <= 1 AND 1 <= c1_max@1 END"; // test column on the left let expr = col("c1").eq(lit(1)); @@ -1873,7 +2105,7 @@ mod tests { #[test] fn row_group_predicate_not_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_min@0 != 1 OR 1 != c1_max@1"; + let expected_expr = "CASE WHEN c1_null_count@2 = c1_row_count@3 THEN false ELSE c1_min@0 != 1 OR 1 != c1_max@1 END"; // test column on the left let expr = col("c1").not_eq(lit(1)); @@ -1893,7 +2125,8 @@ mod tests { #[test] fn row_group_predicate_gt() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_max@0 > 1"; + let expected_expr = + "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_max@0 > 1 END"; // test column on the left let expr = col("c1").gt(lit(1)); @@ -1913,7 +2146,7 @@ mod tests { #[test] fn row_group_predicate_gt_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_max@0 >= 1"; + let expected_expr = "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_max@0 >= 1 END"; // test column on the left let expr = col("c1").gt_eq(lit(1)); @@ -1932,7 +2165,8 @@ mod tests { #[test] fn row_group_predicate_lt() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_min@0 < 1"; + let expected_expr = + "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 < 1 END"; // test column on the left let expr = col("c1").lt(lit(1)); @@ -1952,7 +2186,7 @@ mod tests { #[test] fn row_group_predicate_lt_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_min@0 <= 1"; + let expected_expr = "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 <= 1 END"; // test column on the left let expr = col("c1").lt_eq(lit(1)); @@ -1977,7 +2211,8 @@ mod tests { ]); // test AND operator joining supported c1 < 1 expression and unsupported c2 > c3 expression let expr = col("c1").lt(lit(1)).and(col("c2").lt(col("c3"))); - let expected_expr = "c1_min@0 < 1"; + let expected_expr = + "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 < 1 END"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2043,7 +2278,7 @@ mod tests { #[test] fn row_group_predicate_lt_bool() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Boolean, false)]); - let expected_expr = "c1_min@0 < true"; + let expected_expr = "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 < true END"; // DF doesn't support arithmetic on boolean columns so // this predicate will error when evaluated @@ -2066,7 +2301,21 @@ mod tests { let expr = col("c1") .lt(lit(1)) .and(col("c2").eq(lit(2)).or(col("c2").eq(lit(3)))); - let expected_expr = "c1_min@0 < 1 AND (c2_min@1 <= 2 AND 2 <= c2_max@2 OR c2_min@1 <= 3 AND 3 <= c2_max@2)"; + let expected_expr = "\ + CASE \ + WHEN c1_null_count@1 = c1_row_count@2 THEN false \ + ELSE c1_min@0 < 1 \ + END \ + AND (\ + CASE \ + WHEN c2_null_count@5 = c2_row_count@6 THEN false \ + ELSE c2_min@3 <= 2 AND 2 <= c2_max@4 \ + END \ + OR CASE \ + WHEN c2_null_count@5 = c2_row_count@6 THEN false \ + ELSE c2_min@3 <= 3 AND 3 <= c2_max@4 \ + END\ + )"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut required_columns); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2080,10 +2329,30 @@ mod tests { c1_min_field.with_nullable(true) // could be nullable if stats are not present ) ); + // c1 < 1 should add c1_null_count + let c1_null_count_field = Field::new("c1_null_count", DataType::UInt64, false); + assert_eq!( + required_columns.columns[1], + ( + phys_expr::Column::new("c1", 0), + StatisticsType::NullCount, + c1_null_count_field.with_nullable(true) // could be nullable if stats are not present + ) + ); + // c1 < 1 should add c1_row_count + let c1_row_count_field = Field::new("c1_row_count", DataType::UInt64, false); + assert_eq!( + required_columns.columns[2], + ( + phys_expr::Column::new("c1", 0), + StatisticsType::RowCount, + c1_row_count_field.with_nullable(true) // could be nullable if stats are not present + ) + ); // c2 = 2 should add c2_min and c2_max let c2_min_field = Field::new("c2_min", DataType::Int32, false); assert_eq!( - required_columns.columns[1], + required_columns.columns[3], ( phys_expr::Column::new("c2", 1), StatisticsType::Min, @@ -2092,15 +2361,35 @@ mod tests { ); let c2_max_field = Field::new("c2_max", DataType::Int32, false); assert_eq!( - required_columns.columns[2], + required_columns.columns[4], ( phys_expr::Column::new("c2", 1), StatisticsType::Max, c2_max_field.with_nullable(true) // could be nullable if stats are not present ) ); + // c2 = 2 should add c2_null_count + let c2_null_count_field = Field::new("c2_null_count", DataType::UInt64, false); + assert_eq!( + required_columns.columns[5], + ( + phys_expr::Column::new("c2", 1), + StatisticsType::NullCount, + c2_null_count_field.with_nullable(true) // could be nullable if stats are not present + ) + ); + // c2 = 2 should add c2_row_count + let c2_row_count_field = Field::new("c2_row_count", DataType::UInt64, false); + assert_eq!( + required_columns.columns[6], + ( + phys_expr::Column::new("c2", 1), + StatisticsType::RowCount, + c2_row_count_field.with_nullable(true) // could be nullable if stats are not present + ) + ); // c2 = 3 shouldn't add any new statistics fields - assert_eq!(required_columns.columns.len(), 3); + assert_eq!(required_columns.columns.len(), 7); Ok(()) } @@ -2117,7 +2406,18 @@ mod tests { vec![lit(1), lit(2), lit(3)], false, )); - let expected_expr = "c1_min@0 <= 1 AND 1 <= c1_max@1 OR c1_min@0 <= 2 AND 2 <= c1_max@1 OR c1_min@0 <= 3 AND 3 <= c1_max@1"; + let expected_expr = "CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 <= 1 AND 1 <= c1_max@1 \ + END \ + OR CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 <= 2 AND 2 <= c1_max@1 \ + END \ + OR CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 <= 3 AND 3 <= c1_max@1 \ + END"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2153,9 +2453,19 @@ mod tests { vec![lit(1), lit(2), lit(3)], true, )); - let expected_expr = "(c1_min@0 != 1 OR 1 != c1_max@1) \ - AND (c1_min@0 != 2 OR 2 != c1_max@1) \ - AND (c1_min@0 != 3 OR 3 != c1_max@1)"; + let expected_expr = "\ + CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 != 1 OR 1 != c1_max@1 \ + END \ + AND CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 != 2 OR 2 != c1_max@1 \ + END \ + AND CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 != 3 OR 3 != c1_max@1 \ + END"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2201,7 +2511,24 @@ mod tests { // test c1 in(1, 2) and c2 BETWEEN 4 AND 5 let expr3 = expr1.and(expr2); - let expected_expr = "(c1_min@0 <= 1 AND 1 <= c1_max@1 OR c1_min@0 <= 2 AND 2 <= c1_max@1) AND c2_max@2 >= 4 AND c2_min@3 <= 5"; + let expected_expr = "\ + (\ + CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 <= 1 AND 1 <= c1_max@1 \ + END \ + OR CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 <= 2 AND 2 <= c1_max@1 \ + END\ + ) AND CASE \ + WHEN c2_null_count@5 = c2_row_count@6 THEN false \ + ELSE c2_max@4 >= 4 \ + END \ + AND CASE \ + WHEN c2_null_count@5 = c2_row_count@6 THEN false \ + ELSE c2_min@7 <= 5 \ + END"; let predicate_expr = test_build_predicate_expression(&expr3, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2228,9 +2555,12 @@ mod tests { #[test] fn row_group_predicate_cast() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = - "CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64)"; + let expected_expr = "CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) \ + END"; + // test cast(c1 as int64) = 1 // test column on the left let expr = cast(col("c1"), DataType::Int64).eq(lit(ScalarValue::Int64(Some(1)))); let predicate_expr = @@ -2243,7 +2573,10 @@ mod tests { test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); - let expected_expr = "TRY_CAST(c1_max@0 AS Int64) > 1"; + let expected_expr = "CASE \ + WHEN c1_null_count@1 = c1_row_count@2 THEN false \ + ELSE TRY_CAST(c1_max@0 AS Int64) > 1 \ + END"; // test column on the left let expr = @@ -2275,7 +2608,18 @@ mod tests { ], false, )); - let expected_expr = "CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) OR CAST(c1_min@0 AS Int64) <= 2 AND 2 <= CAST(c1_max@1 AS Int64) OR CAST(c1_min@0 AS Int64) <= 3 AND 3 <= CAST(c1_max@1 AS Int64)"; + let expected_expr = "CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) \ + END \ + OR CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE CAST(c1_min@0 AS Int64) <= 2 AND 2 <= CAST(c1_max@1 AS Int64) \ + END \ + OR CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE CAST(c1_min@0 AS Int64) <= 3 AND 3 <= CAST(c1_max@1 AS Int64) \ + END"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2289,10 +2633,18 @@ mod tests { ], true, )); - let expected_expr = - "(CAST(c1_min@0 AS Int64) != 1 OR 1 != CAST(c1_max@1 AS Int64)) \ - AND (CAST(c1_min@0 AS Int64) != 2 OR 2 != CAST(c1_max@1 AS Int64)) \ - AND (CAST(c1_min@0 AS Int64) != 3 OR 3 != CAST(c1_max@1 AS Int64))"; + let expected_expr = "CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE CAST(c1_min@0 AS Int64) != 1 OR 1 != CAST(c1_max@1 AS Int64) \ + END \ + AND CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE CAST(c1_min@0 AS Int64) != 2 OR 2 != CAST(c1_max@1 AS Int64) \ + END \ + AND CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE CAST(c1_min@0 AS Int64) != 3 OR 3 != CAST(c1_max@1 AS Int64) \ + END"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2819,7 +3171,7 @@ mod tests { let expected_ret = &[false, true, true, true, false]; prune_with_expr( - // i IS NULL, with actual null statistcs + // i IS NULL, with actual null statistics col("i").is_null(), &schema, &statistics, @@ -2827,6 +3179,78 @@ mod tests { ); } + #[test] + fn prune_int32_column_is_known_all_null() { + let (schema, statistics) = int32_setup(); + + // Expression "i < 0" + // i [-5, 5] ==> some rows could pass (must keep) + // i [1, 11] ==> no rows can pass (not keep) + // i [-11, -1] ==> all rows must pass (must keep) + // i [NULL, NULL] ==> unknown (must keep) + // i [1, NULL] ==> no rows can pass (not keep) + let expected_ret = &[true, false, true, true, false]; + + prune_with_expr( + // i < 0 + col("i").lt(lit(0)), + &schema, + &statistics, + expected_ret, + ); + + // provide row counts for each column + let statistics = statistics.with_row_counts( + "i", + vec![ + Some(10), // 10 rows of data + Some(9), // 9 rows of data + None, // unknown row counts + Some(4), + Some(10), + ], + ); + + // pruning result is still the same if we only know row counts + prune_with_expr( + // i < 0, with only row counts statistics + col("i").lt(lit(0)), + &schema, + &statistics, + expected_ret, + ); + + // provide null counts for each column + let statistics = statistics.with_null_counts( + "i", + vec![ + Some(0), // no nulls + Some(1), // 1 null + None, // unknown nulls + Some(4), // 4 nulls, which is the same as the row counts, i.e. this column is all null (don't keep) + Some(0), // 0 nulls (max=null too which means no known max) + ], + ); + + // Expression "i < 0" with actual null and row counts statistics + // col | min, max | row counts | null counts | + // ----+--------------+------------+-------------+ + // i | [-5, 5] | 10 | 0 | ==> Some rows could pass (must keep) + // i | [1, 11] | 9 | 1 | ==> No rows can pass (not keep) + // i | [-11,-1] | Unknown | Unknown | ==> All rows must pass (must keep) + // i | [NULL, NULL] | 4 | 4 | ==> The column is all null (not keep) + // i | [1, NULL] | 10 | 0 | ==> No rows can pass (not keep) + let expected_ret = &[true, false, true, false, false]; + + prune_with_expr( + // i < 0, with actual null and row counts statistics + col("i").lt(lit(0)), + &schema, + &statistics, + expected_ret, + ); + } + #[test] fn prune_cast_column_scalar() { // The data type of column i is INT32 diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt index fe0f6c1e8139..f9699a5fda8f 100644 --- a/datafusion/sqllogictest/test_files/repartition_scan.slt +++ b/datafusion/sqllogictest/test_files/repartition_scan.slt @@ -61,7 +61,7 @@ Filter: parquet_table.column1 != Int32(42) physical_plan CoalesceBatchesExec: target_batch_size=8192 --FilterExec: column1@0 != 42 -----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..104], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:104..208], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:208..312], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:312..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] +----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..104], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:104..208], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:208..312], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:312..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] # disable round robin repartitioning statement ok @@ -77,7 +77,7 @@ Filter: parquet_table.column1 != Int32(42) physical_plan CoalesceBatchesExec: target_batch_size=8192 --FilterExec: column1@0 != 42 -----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..104], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:104..208], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:208..312], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:312..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] +----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..104], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:104..208], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:208..312], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:312..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] # enable round robin repartitioning again statement ok @@ -102,7 +102,7 @@ SortPreservingMergeExec: [column1@0 ASC NULLS LAST] --SortExec: expr=[column1@0 ASC NULLS LAST] ----CoalesceBatchesExec: target_batch_size=8192 ------FilterExec: column1@0 != 42 ---------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..205], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:205..405, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..5], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:5..210], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:210..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] +--------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..205], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:205..405, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..5], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:5..210], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:210..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] ## Read the files as though they are ordered @@ -138,7 +138,7 @@ physical_plan SortPreservingMergeExec: [column1@0 ASC NULLS LAST] --CoalesceBatchesExec: target_batch_size=8192 ----FilterExec: column1@0 != 42 -------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..202], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..207], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:207..414], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:202..405]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] +------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..202], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..207], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:207..414], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:202..405]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] # Cleanup statement ok From b0b329ba39403b9e87156d6f9b8c5464dc6d2480 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Tue, 19 Mar 2024 06:57:52 +0800 Subject: [PATCH 11/35] Suppress self update for windows CI runner (#9661) * suppress self update for window Signed-off-by: jayzhan211 * Update .github/actions/setup-windows-builder/action.yaml --------- Signed-off-by: jayzhan211 Co-authored-by: Andrew Lamb --- .github/actions/setup-windows-builder/action.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/actions/setup-windows-builder/action.yaml b/.github/actions/setup-windows-builder/action.yaml index 9ab5c4a8b1bb..a26a34a3db93 100644 --- a/.github/actions/setup-windows-builder/action.yaml +++ b/.github/actions/setup-windows-builder/action.yaml @@ -38,8 +38,8 @@ runs: - name: Setup Rust toolchain shell: bash run: | - rustup update stable - rustup toolchain install stable + # Avoid self update to avoid CI failures: https://github.com/apache/arrow-datafusion/issues/9653 + rustup toolchain install stable --no-self-update rustup default stable rustup component add rustfmt - name: Configure rust runtime env From 8438d2b1ea67fda64955839fb4bd4ed88b861ade Mon Sep 17 00:00:00 2001 From: Suriya Kandaswamy Date: Tue, 19 Mar 2024 10:14:42 -0400 Subject: [PATCH 12/35] add schema to SQL ast builder (#9624) * add schema to ast builder * add schema test --- datafusion/sql/src/unparser/plan.rs | 9 ++++++--- datafusion/sql/tests/sql_integration.rs | 9 +++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index e1f5135efda9..c9b0a8a04c7e 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -124,9 +124,12 @@ impl Unparser<'_> { match plan { LogicalPlan::TableScan(scan) => { let mut builder = TableRelationBuilder::default(); - builder.name(ast::ObjectName(vec![ - self.new_ident(scan.table_name.table().to_string()) - ])); + let mut table_parts = vec![]; + if let Some(schema_name) = scan.table_name.schema() { + table_parts.push(self.new_ident(schema_name.to_string())); + } + table_parts.push(self.new_ident(scan.table_name.table().to_string())); + builder.name(ast::ObjectName(table_parts)); relation.table(builder); Ok(()) diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 6d335f1f8fc9..47638e58ff00 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -41,6 +41,15 @@ use rstest::rstest; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; use sqlparser::parser::Parser; +#[test] +fn test_schema_support() { + quick_test( + "SELECT * FROM s1.test", + "Projection: s1.test.t_date32, s1.test.t_date64\ + \n TableScan: s1.test", + ); +} + #[test] fn parse_decimals() { let test_data = [ From 9b098eef6f7d8b6d1162ccbdc9053f8e1cb999d4 Mon Sep 17 00:00:00 2001 From: Val Lorentz Date: Tue, 19 Mar 2024 15:20:01 +0100 Subject: [PATCH 13/35] Add tests for row group pruning on strings (#9642) --- datafusion/core/tests/parquet/mod.rs | 112 +++++++++++++++++- .../core/tests/parquet/row_group_pruning.rs | 102 ++++++++++++++++ 2 files changed, 211 insertions(+), 3 deletions(-) diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index c60780919489..3fe51288e79a 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -19,9 +19,9 @@ use arrow::array::Decimal128Array; use arrow::{ array::{ - Array, ArrayRef, Date32Array, Date64Array, Float64Array, Int32Array, StringArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, + Array, ArrayRef, BinaryArray, Date32Array, Date64Array, FixedSizeBinaryArray, + Float64Array, Int32Array, StringArray, TimestampMicrosecondArray, + TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, }, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, @@ -70,6 +70,7 @@ enum Scenario { DecimalBloomFilterInt64, DecimalLargePrecision, DecimalLargePrecisionBloomFilter, + ByteArray, PeriodsInColumnNames, } @@ -506,6 +507,51 @@ fn make_date_batch(offset: Duration) -> RecordBatch { .unwrap() } +/// returns a batch with two columns (note "service.name" is the name +/// of the column. It is *not* a table named service.name +/// +/// name | service.name +fn make_bytearray_batch( + name: &str, + string_values: Vec<&str>, + binary_values: Vec<&[u8]>, + fixedsize_values: Vec<&[u8; 3]>, +) -> RecordBatch { + let num_rows = string_values.len(); + let name: StringArray = std::iter::repeat(Some(name)).take(num_rows).collect(); + let service_string: StringArray = string_values.iter().map(Some).collect(); + let service_binary: BinaryArray = binary_values.iter().map(Some).collect(); + let service_fixedsize: FixedSizeBinaryArray = fixedsize_values + .iter() + .map(|value| Some(value.as_slice())) + .collect::>() + .into(); + + let schema = Schema::new(vec![ + Field::new("name", name.data_type().clone(), true), + // note the column name has a period in it! + Field::new("service_string", service_string.data_type().clone(), true), + Field::new("service_binary", service_binary.data_type().clone(), true), + Field::new( + "service_fixedsize", + service_fixedsize.data_type().clone(), + true, + ), + ]); + let schema = Arc::new(schema); + + RecordBatch::try_new( + schema, + vec![ + Arc::new(name), + Arc::new(service_string), + Arc::new(service_binary), + Arc::new(service_fixedsize), + ], + ) + .unwrap() +} + /// returns a batch with two columns (note "service.name" is the name /// of the column. It is *not* a table named service.name /// @@ -604,6 +650,66 @@ fn create_data_batch(scenario: Scenario) -> Vec { make_decimal_batch(vec![100000, 200000, 300000, 400000, 600000], 38, 5), ] } + Scenario::ByteArray => { + // frontends first, then backends. All in order, except frontends 4 and 7 + // are swapped to cause a statistics false positive on the 'fixed size' column. + vec![ + make_bytearray_batch( + "all frontends", + vec![ + "frontend one", + "frontend two", + "frontend three", + "frontend seven", + "frontend five", + ], + vec![ + b"frontend one", + b"frontend two", + b"frontend three", + b"frontend seven", + b"frontend five", + ], + vec![b"fe1", b"fe2", b"fe3", b"fe7", b"fe5"], + ), + make_bytearray_batch( + "mixed", + vec![ + "frontend six", + "frontend four", + "backend one", + "backend two", + "backend three", + ], + vec![ + b"frontend six", + b"frontend four", + b"backend one", + b"backend two", + b"backend three", + ], + vec![b"fe6", b"fe4", b"be1", b"be2", b"be3"], + ), + make_bytearray_batch( + "all backends", + vec![ + "backend four", + "backend five", + "backend six", + "backend seven", + "backend eight", + ], + vec![ + b"backend four", + b"backend five", + b"backend six", + b"backend seven", + b"backend eight", + ], + vec![b"be4", b"be5", b"be6", b"be7", b"be8"], + ), + ] + } Scenario::PeriodsInColumnNames => { vec![ // all frontend diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index b7038ef1a73f..406eb721bf94 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -744,6 +744,108 @@ async fn prune_decimal_in_list() { .await; } +#[tokio::test] +async fn prune_string_eq_match() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_string FROM t WHERE service_string = 'backend one'", + ) + .with_expected_errors(Some(0)) + // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(1)) + .with_expected_rows(1) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_string_eq_no_match() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_string FROM t WHERE service_string = 'backend nine'", + ) + .with_expected_errors(Some(0)) + // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(1)) + .with_expected_rows(0) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_string FROM t WHERE service_string = 'frontend nine'", + ) + .with_expected_errors(Some(0)) + // false positive on 'all frontends' batch: 'frontend five' < 'frontend nine' < 'frontend two' + // false positive on 'mixed' batch: 'backend one' < 'frontend nine' < 'frontend six' + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(2)) + .with_expected_rows(0) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_string_neq() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_string FROM t WHERE service_string != 'backend one'", + ) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(3)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(14) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_string_lt() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_string FROM t WHERE service_string < 'backend one'", + ) + .with_expected_errors(Some(0)) + // matches 'all backends' only + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(3) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_string FROM t WHERE service_string < 'backend zero'", + ) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + // all backends from 'mixed' and 'all backends' + .with_expected_rows(8) + .test_row_group_prune() + .await; +} + #[tokio::test] async fn prune_periods_in_column_names() { // There are three row groups for "service.name", each with 5 rows = 15 rows total From 3c3b22866a7ece784208e9d499119b2e13399762 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 19 Mar 2024 11:38:25 -0400 Subject: [PATCH 14/35] Fix incorrect results with multiple `COUNT(DISTINCT..)` aggregates on dictionaries (#9679) * Add test for multiple count distincts on a dictionary * Fix accumulator merge bug * Fix cleanup code --- datafusion/common/src/scalar/mod.rs | 2 +- .../src/aggregate/count_distinct/mod.rs | 32 +++++++-- .../sqllogictest/test_files/dictionary.slt | 67 +++++++++++++++++++ 3 files changed, 93 insertions(+), 8 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 5ace44f24b69..316624175e1c 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1746,7 +1746,7 @@ impl ScalarValue { } /// Converts `Vec` where each element has type corresponding to - /// `data_type`, to a [`ListArray`]. + /// `data_type`, to a single element [`ListArray`]. /// /// Example /// ``` diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs index 71782fcc5f9b..fb5e7710496c 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs @@ -47,7 +47,7 @@ use crate::binary_map::OutputType; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; -/// Expression for a COUNT(DISTINCT) aggregation. +/// Expression for a `COUNT(DISTINCT)` aggregation. #[derive(Debug)] pub struct DistinctCount { /// Column name @@ -100,6 +100,7 @@ impl AggregateExpr for DistinctCount { use TimeUnit::*; Ok(match &self.state_data_type { + // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator Int8 => Box::new(PrimitiveDistinctCountAccumulator::::new()), Int16 => Box::new(PrimitiveDistinctCountAccumulator::::new()), Int32 => Box::new(PrimitiveDistinctCountAccumulator::::new()), @@ -157,6 +158,7 @@ impl AggregateExpr for DistinctCount { OutputType::Binary, )), + // Use the generic accumulator based on `ScalarValue` for all other types _ => Box::new(DistinctCountAccumulator { values: HashSet::default(), state_data_type: self.state_data_type.clone(), @@ -183,7 +185,11 @@ impl PartialEq for DistinctCount { } /// General purpose distinct accumulator that works for any DataType by using -/// [`ScalarValue`]. Some types have specialized accumulators that are (much) +/// [`ScalarValue`]. +/// +/// It stores intermediate results as a `ListArray` +/// +/// Note that many types have specialized accumulators that are (much) /// more efficient such as [`PrimitiveDistinctCountAccumulator`] and /// [`BytesDistinctCountAccumulator`] #[derive(Debug)] @@ -193,8 +199,9 @@ struct DistinctCountAccumulator { } impl DistinctCountAccumulator { - // calculating the size for fixed length values, taking first batch size * number of batches - // This method is faster than .full_size(), however it is not suitable for variable length values like strings or complex types + // calculating the size for fixed length values, taking first batch size * + // number of batches This method is faster than .full_size(), however it is + // not suitable for variable length values like strings or complex types fn fixed_size(&self) -> usize { std::mem::size_of_val(self) + (std::mem::size_of::() * self.values.capacity()) @@ -207,7 +214,8 @@ impl DistinctCountAccumulator { + std::mem::size_of::() } - // calculates the size as accurate as possible, call to this method is expensive + // calculates the size as accurately as possible. Note that calling this + // method is expensive fn full_size(&self) -> usize { std::mem::size_of_val(self) + (std::mem::size_of::() * self.values.capacity()) @@ -221,6 +229,7 @@ impl DistinctCountAccumulator { } impl Accumulator for DistinctCountAccumulator { + /// Returns the distinct values seen so far as (one element) ListArray. fn state(&mut self) -> Result> { let scalars = self.values.iter().cloned().collect::>(); let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type); @@ -246,6 +255,11 @@ impl Accumulator for DistinctCountAccumulator { }) } + /// Merges multiple sets of distinct values into the current set. + /// + /// The input to this function is a `ListArray` with **multiple** rows, + /// where each row contains the values from a partial aggregate's phase (e.g. + /// the result of calling `Self::state` on multiple accumulators). fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { if states.is_empty() { return Ok(()); @@ -253,8 +267,12 @@ impl Accumulator for DistinctCountAccumulator { assert_eq!(states.len(), 1, "array_agg states must be singleton!"); let array = &states[0]; let list_array = array.as_list::(); - let inner_array = list_array.value(0); - self.update_batch(&[inner_array]) + for inner_array in list_array.iter() { + let inner_array = inner_array + .expect("counts are always non null, so are intermediate results"); + self.update_batch(&[inner_array])?; + } + Ok(()) } fn evaluate(&mut self) -> Result { diff --git a/datafusion/sqllogictest/test_files/dictionary.slt b/datafusion/sqllogictest/test_files/dictionary.slt index 002aade2528e..af7bf5cb16e8 100644 --- a/datafusion/sqllogictest/test_files/dictionary.slt +++ b/datafusion/sqllogictest/test_files/dictionary.slt @@ -280,3 +280,70 @@ ORDER BY 2023-12-20T01:20:00 1000 f2 foo 2023-12-20T01:30:00 1000 f1 32.0 2023-12-20T01:30:00 1000 f2 foo + +# Cleanup +statement ok +drop view m1; + +statement ok +drop view m2; + +###### +# Create a table using UNION ALL to get 2 partitions (very important) +###### +statement ok +create table m3_source as + select * from (values('foo', 'bar', 1)) + UNION ALL + select * from (values('foo', 'baz', 1)); + +###### +# Now, create a table with the same data, but column2 has type `Dictionary(Int32)` to trigger the fallback code +###### +statement ok +create table m3 as + select + column1, + arrow_cast(column2, 'Dictionary(Int32, Utf8)') as "column2", + column3 +from m3_source; + +# there are two values in column2 +query T?I rowsort +SELECT * +FROM m3; +---- +foo bar 1 +foo baz 1 + +# There is 1 distinct value in column1 +query I +SELECT count(distinct column1) +FROM m3 +GROUP BY column3; +---- +1 + +# There are 2 distinct values in column2 +query I +SELECT count(distinct column2) +FROM m3 +GROUP BY column3; +---- +2 + +# Should still get the same results when querying in the same query +query II +SELECT count(distinct column1), count(distinct column2) +FROM m3 +GROUP BY column3; +---- +1 2 + + +# Cleanup +statement ok +drop table m3; + +statement ok +drop table m3_source; From b87dd6143c2dc089b07f74780bd525c4369e68a3 Mon Sep 17 00:00:00 2001 From: Val Lorentz Date: Tue, 19 Mar 2024 16:43:11 +0100 Subject: [PATCH 15/35] Add support for Bloom filters on binary columns (#9644) --- .../physical_plan/parquet/row_groups.rs | 1 + .../core/tests/parquet/row_group_pruning.rs | 102 ++++++++++++++++++ 2 files changed, 103 insertions(+) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index a0bb5ab71204..9cd46994960f 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -225,6 +225,7 @@ impl PruningStatistics for BloomFilterStatistics { .map(|value| { match value { ScalarValue::Utf8(Some(v)) => sbbf.check(&v.as_str()), + ScalarValue::Binary(Some(v)) => sbbf.check(v), ScalarValue::Boolean(Some(v)) => sbbf.check(v), ScalarValue::Float64(Some(v)) => sbbf.check(v), ScalarValue::Float32(Some(v)) => sbbf.check(v), diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index 406eb721bf94..55112193502d 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -846,6 +846,108 @@ async fn prune_string_lt() { .await; } +#[tokio::test] +async fn prune_binary_eq_match() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_binary FROM t WHERE service_binary = CAST('backend one' AS bytea)", + ) + .with_expected_errors(Some(0)) + // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(1)) + .with_expected_rows(1) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_binary_eq_no_match() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_binary FROM t WHERE service_binary = CAST('backend nine' AS bytea)", + ) + .with_expected_errors(Some(0)) + // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(1)) + .with_expected_rows(0) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_binary FROM t WHERE service_binary = CAST('frontend nine' AS bytea)", + ) + .with_expected_errors(Some(0)) + // false positive on 'all frontends' batch: 'frontend five' < 'frontend nine' < 'frontend two' + // false positive on 'mixed' batch: 'backend one' < 'frontend nine' < 'frontend six' + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(2)) + .with_expected_rows(0) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_binary_neq() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_binary FROM t WHERE service_binary != CAST('backend one' AS bytea)", + ) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(3)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(14) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_binary_lt() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_binary FROM t WHERE service_binary < CAST('backend one' AS bytea)", + ) + .with_expected_errors(Some(0)) + // matches 'all backends' only + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(3) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_binary FROM t WHERE service_binary < CAST('backend zero' AS bytea)", + ) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + // all backends from 'mixed' and 'all backends' + .with_expected_rows(8) + .test_row_group_prune() + .await; +} + #[tokio::test] async fn prune_periods_in_column_names() { // There are three row groups for "service.name", each with 5 rows = 15 rows total From 7af69f9768497060343ae2a6fbd1991e9a047dce Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Wed, 20 Mar 2024 05:22:12 +1300 Subject: [PATCH 16/35] Update Arrow/Parquet to `51.0.0`, tonic to `0.11` (#9613) * Prepare for arrow 51 * Fix datafusion-proto * Update deserialize_to_struct * Format * Update pins * Update datafusion-cli Cargo.lock * Remove stale comment * Add comment to seconds --------- Co-authored-by: Andrew Lamb --- Cargo.toml | 18 +- datafusion-cli/Cargo.lock | 88 +++---- datafusion-cli/Cargo.toml | 4 +- datafusion-examples/Cargo.toml | 3 +- .../examples/deserialize_to_struct.rs | 58 ++--- .../examples/flight/flight_server.rs | 9 +- .../examples/flight/flight_sql_server.rs | 3 + .../common/src/file_options/parquet_writer.rs | 1 + datafusion/common/src/scalar/mod.rs | 10 +- .../src/datasource/avro_to_arrow/schema.rs | 6 + .../src/datasource/file_format/parquet.rs | 4 - .../datasource/physical_plan/parquet/mod.rs | 8 +- .../functions/src/datetime/date_part.rs | 214 ++++++------------ datafusion/proto/src/logical_plan/to_proto.rs | 3 + datafusion/sql/src/unparser/expr.rs | 4 + 15 files changed, 195 insertions(+), 238 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 48e555bd5527..d9e69e53db7c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,14 +57,14 @@ version = "36.0.0" # for the inherited dependency but cannot do the reverse (override from true to false). # # See for more detaiils: https://github.com/rust-lang/cargo/issues/11329 -arrow = { version = "50.0.0", features = ["prettyprint"] } -arrow-array = { version = "50.0.0", default-features = false, features = ["chrono-tz"] } -arrow-buffer = { version = "50.0.0", default-features = false } -arrow-flight = { version = "50.0.0", features = ["flight-sql-experimental"] } -arrow-ipc = { version = "50.0.0", default-features = false, features = ["lz4"] } -arrow-ord = { version = "50.0.0", default-features = false } -arrow-schema = { version = "50.0.0", default-features = false } -arrow-string = { version = "50.0.0", default-features = false } +arrow = { version = "51.0.0", features = ["prettyprint"] } +arrow-array = { version = "51.0.0", default-features = false, features = ["chrono-tz"] } +arrow-buffer = { version = "51.0.0", default-features = false } +arrow-flight = { version = "51.0.0", features = ["flight-sql-experimental"] } +arrow-ipc = { version = "51.0.0", default-features = false, features = ["lz4"] } +arrow-ord = { version = "51.0.0", default-features = false } +arrow-schema = { version = "51.0.0", default-features = false } +arrow-string = { version = "51.0.0", default-features = false } async-trait = "0.1.73" bigdecimal = "=0.4.1" bytes = "1.4" @@ -95,7 +95,7 @@ log = "^0.4" num_cpus = "1.13.0" object_store = { version = "0.9.0", default-features = false } parking_lot = "0.12" -parquet = { version = "50.0.0", default-features = false, features = ["arrow", "async", "object_store"] } +parquet = { version = "51.0.0", default-features = false, features = ["arrow", "async", "object_store"] } rand = "0.8" rstest = "0.18.0" serde_json = "1" diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 8e2a2c353e2d..51cccf60a1e4 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -130,9 +130,9 @@ checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" [[package]] name = "arrow" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa285343fba4d829d49985bdc541e3789cf6000ed0e84be7c039438df4a4e78c" +checksum = "219d05930b81663fd3b32e3bde8ce5bff3c4d23052a99f11a8fa50a3b47b2658" dependencies = [ "arrow-arith", "arrow-array", @@ -151,9 +151,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "753abd0a5290c1bcade7c6623a556f7d1659c5f4148b140b5b63ce7bd1a45705" +checksum = "0272150200c07a86a390be651abdd320a2d12e84535f0837566ca87ecd8f95e0" dependencies = [ "arrow-array", "arrow-buffer", @@ -166,9 +166,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d390feeb7f21b78ec997a4081a025baef1e2e0d6069e181939b61864c9779609" +checksum = "8010572cf8c745e242d1b632bd97bd6d4f40fefed5ed1290a8f433abaa686fea" dependencies = [ "ahash", "arrow-buffer", @@ -183,9 +183,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69615b061701bcdffbc62756bc7e85c827d5290b472b580c972ebbbf690f5aa4" +checksum = "0d0a2432f0cba5692bf4cb757469c66791394bac9ec7ce63c1afe74744c37b27" dependencies = [ "bytes", "half", @@ -194,28 +194,30 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e448e5dd2f4113bf5b74a1f26531708f5edcacc77335b7066f9398f4bcf4cdef" +checksum = "9abc10cd7995e83505cc290df9384d6e5412b207b79ce6bdff89a10505ed2cba" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "arrow-select", - "base64 0.21.7", + "atoi", + "base64 0.22.0", "chrono", "comfy-table", "half", "lexical-core", "num", + "ryu", ] [[package]] name = "arrow-csv" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46af72211f0712612f5b18325530b9ad1bfbdc87290d5fbfd32a7da128983781" +checksum = "95cbcba196b862270bf2a5edb75927380a7f3a163622c61d40cbba416a6305f2" dependencies = [ "arrow-array", "arrow-buffer", @@ -232,9 +234,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67d644b91a162f3ad3135ce1184d0a31c28b816a581e08f29e8e9277a574c64e" +checksum = "2742ac1f6650696ab08c88f6dd3f0eb68ce10f8c253958a18c943a68cd04aec5" dependencies = [ "arrow-buffer", "arrow-schema", @@ -244,9 +246,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03dea5e79b48de6c2e04f03f62b0afea7105be7b77d134f6c5414868feefb80d" +checksum = "a42ea853130f7e78b9b9d178cb4cd01dee0f78e64d96c2949dc0a915d6d9e19d" dependencies = [ "arrow-array", "arrow-buffer", @@ -259,9 +261,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8950719280397a47d37ac01492e3506a8a724b3fb81001900b866637a829ee0f" +checksum = "eaafb5714d4e59feae964714d724f880511500e3569cc2a94d02456b403a2a49" dependencies = [ "arrow-array", "arrow-buffer", @@ -279,9 +281,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ed9630979034077982d8e74a942b7ac228f33dd93a93b615b4d02ad60c260be" +checksum = "e3e6b61e3dc468f503181dccc2fc705bdcc5f2f146755fa5b56d0a6c5943f412" dependencies = [ "arrow-array", "arrow-buffer", @@ -294,9 +296,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "007035e17ae09c4e8993e4cb8b5b96edf0afb927cd38e2dff27189b274d83dcf" +checksum = "848ee52bb92eb459b811fb471175ea3afcf620157674c8794f539838920f9228" dependencies = [ "ahash", "arrow-array", @@ -309,15 +311,15 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ff3e9c01f7cd169379d269f926892d0e622a704960350d09d331be3ec9e0029" +checksum = "02d9483aaabe910c4781153ae1b6ae0393f72d9ef757d38d09d450070cf2e528" [[package]] name = "arrow-select" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ce20973c1912de6514348e064829e50947e35977bb9d7fb637dc99ea9ffd78c" +checksum = "849524fa70e0e3c5ab58394c770cb8f514d0122d20de08475f7b472ed8075830" dependencies = [ "ahash", "arrow-array", @@ -329,15 +331,16 @@ dependencies = [ [[package]] name = "arrow-string" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00f3b37f2aeece31a2636d1b037dabb69ef590e03bdc7eb68519b51ec86932a7" +checksum = "9373cb5a021aee58863498c37eb484998ef13377f69989c6c5ccfbd258236cdb" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "arrow-select", + "memchr", "num", "regex", "regex-syntax", @@ -387,6 +390,15 @@ dependencies = [ "syn 2.0.53", ] +[[package]] +name = "atoi" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" +dependencies = [ + "num-traits", +] + [[package]] name = "atty" version = "0.2.14" @@ -739,9 +751,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.2" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" +checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" [[package]] name = "blake2" @@ -2128,7 +2140,7 @@ version = "0.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "85c833ca1e66078851dba29046874e38f08b2c883700aa29a03ddd3b23814ee8" dependencies = [ - "bitflags 2.4.2", + "bitflags 2.5.0", "libc", "redox_syscall", ] @@ -2442,9 +2454,9 @@ dependencies = [ [[package]] name = "parquet" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "547b92ebf0c1177e3892f44c8f79757ee62e678d564a9834189725f2c5b7a750" +checksum = "096795d4f47f65fd3ee1ec5a98b77ab26d602f2cc785b0e4be5443add17ecc32" dependencies = [ "ahash", "arrow-array", @@ -2454,7 +2466,7 @@ dependencies = [ "arrow-ipc", "arrow-schema", "arrow-select", - "base64 0.21.7", + "base64 0.22.0", "brotli", "bytes", "chrono", @@ -2903,7 +2915,7 @@ version = "0.38.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949" dependencies = [ - "bitflags 2.4.2", + "bitflags 2.5.0", "errno", "libc", "linux-raw-sys", @@ -3720,9 +3732,9 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" [[package]] name = "uuid" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f00cc9702ca12d3c81455259621e676d0f7251cec66a21e98fe2e9a37db93b2a" +checksum = "a183cf7feeba97b4dd1c0d46788634f6221d87fa961b305bed08c851829efcc0" dependencies = [ "getrandom", "serde", diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index ad506762f0d0..da744a06f3aa 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -30,7 +30,7 @@ rust-version = "1.72" readme = "README.md" [dependencies] -arrow = "50.0.0" +arrow = "51.0.0" async-trait = "0.1.41" aws-config = "0.55" aws-credential-types = "0.55" @@ -52,7 +52,7 @@ futures = "0.3" mimalloc = { version = "0.1", default-features = false } object_store = { version = "0.9.0", features = ["aws", "gcp", "http"] } parking_lot = { version = "0.12" } -parquet = { version = "50.0.0", default-features = false } +parquet = { version = "51.0.0", default-features = false } regex = "1.8" rustyline = "11.0" tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot", "signal"] } diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index ad2a49fb352e..2b6e869ec500 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -74,7 +74,6 @@ serde = { version = "1.0.136", features = ["derive"] } serde_json = { workspace = true } tempfile = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } -# 0.10 and 0.11 are incompatible. Need to upgrade tonic to 0.11 when upgrading to arrow 51 -tonic = "0.10" +tonic = "0.11" url = { workspace = true } uuid = "1.2" diff --git a/datafusion-examples/examples/deserialize_to_struct.rs b/datafusion-examples/examples/deserialize_to_struct.rs index e999fc4dac3e..985cab703a5c 100644 --- a/datafusion-examples/examples/deserialize_to_struct.rs +++ b/datafusion-examples/examples/deserialize_to_struct.rs @@ -15,14 +15,13 @@ // specific language governing permissions and limitations // under the License. +use arrow::array::AsArray; +use arrow::datatypes::{Float64Type, Int32Type}; use datafusion::error::Result; use datafusion::prelude::*; -use serde::Deserialize; +use futures::StreamExt; /// This example shows that it is possible to convert query results into Rust structs . -/// It will collect the query results into RecordBatch, then convert it to serde_json::Value. -/// Then, serde_json::Value is turned into Rust's struct. -/// Any datatype with `Deserialize` implemeneted works. #[tokio::main] async fn main() -> Result<()> { let data_list = Data::new().await?; @@ -30,10 +29,10 @@ async fn main() -> Result<()> { Ok(()) } -#[derive(Deserialize, Debug)] +#[derive(Debug)] struct Data { #[allow(dead_code)] - int_col: i64, + int_col: i32, #[allow(dead_code)] double_col: f64, } @@ -41,35 +40,36 @@ struct Data { impl Data { pub async fn new() -> Result> { // this group is almost the same as the one you find it in parquet_sql.rs - let batches = { - let ctx = SessionContext::new(); + let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + let testdata = datafusion::test_util::parquet_test_data(); - ctx.register_parquet( - "alltypes_plain", - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) - .await?; + ctx.register_parquet( + "alltypes_plain", + &format!("{testdata}/alltypes_plain.parquet"), + ParquetReadOptions::default(), + ) + .await?; - let df = ctx - .sql("SELECT int_col, double_col FROM alltypes_plain") - .await?; + let df = ctx + .sql("SELECT int_col, double_col FROM alltypes_plain") + .await?; - df.clone().show().await?; + df.clone().show().await?; - df.collect().await? - }; - let batches: Vec<_> = batches.iter().collect(); + let mut stream = df.execute_stream().await?; + let mut list = vec![]; + while let Some(b) = stream.next().await.transpose()? { + let int_col = b.column(0).as_primitive::(); + let float_col = b.column(1).as_primitive::(); - // converts it to serde_json type and then convert that into Rust type - let list = arrow::json::writer::record_batches_to_json_rows(&batches[..])? - .into_iter() - .map(|val| serde_json::from_value(serde_json::Value::Object(val))) - .take_while(|val| val.is_ok()) - .map(|val| val.unwrap()) - .collect(); + for (i, f) in int_col.values().iter().zip(float_col.values()) { + list.push(Data { + int_col: *i, + double_col: *f, + }) + } + } Ok(list) } diff --git a/datafusion-examples/examples/flight/flight_server.rs b/datafusion-examples/examples/flight/flight_server.rs index cb7b7c28d909..f9d1b8029f04 100644 --- a/datafusion-examples/examples/flight/flight_server.rs +++ b/datafusion-examples/examples/flight/flight_server.rs @@ -18,7 +18,7 @@ use arrow::ipc::writer::{DictionaryTracker, IpcDataGenerator}; use std::sync::Arc; -use arrow_flight::SchemaAsIpc; +use arrow_flight::{PollInfo, SchemaAsIpc}; use datafusion::arrow::error::ArrowError; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::{ListingOptions, ListingTableUrl}; @@ -177,6 +177,13 @@ impl FlightService for FlightServiceImpl { ) -> Result, Status> { Err(Status::unimplemented("Not yet implemented")) } + + async fn poll_flight_info( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } } fn to_tonic_err(e: datafusion::error::DataFusionError) -> Status { diff --git a/datafusion-examples/examples/flight/flight_sql_server.rs b/datafusion-examples/examples/flight/flight_sql_server.rs index 35d475623062..ed9457643b7d 100644 --- a/datafusion-examples/examples/flight/flight_sql_server.rs +++ b/datafusion-examples/examples/flight/flight_sql_server.rs @@ -307,6 +307,8 @@ impl FlightSqlService for FlightSqlServiceImpl { let endpoint = FlightEndpoint { ticket: Some(ticket), location: vec![], + expiration_time: None, + app_metadata: Default::default(), }; let endpoints = vec![endpoint]; @@ -329,6 +331,7 @@ impl FlightSqlService for FlightSqlServiceImpl { total_records: -1_i64, total_bytes: -1_i64, ordered: false, + app_metadata: Default::default(), }; let resp = Response::new(info); Ok(resp) diff --git a/datafusion/common/src/file_options/parquet_writer.rs b/datafusion/common/src/file_options/parquet_writer.rs index e8a350e8d389..28e73ba48f53 100644 --- a/datafusion/common/src/file_options/parquet_writer.rs +++ b/datafusion/common/src/file_options/parquet_writer.rs @@ -156,6 +156,7 @@ pub(crate) fn parse_encoding_string( "plain" => Ok(parquet::basic::Encoding::PLAIN), "plain_dictionary" => Ok(parquet::basic::Encoding::PLAIN_DICTIONARY), "rle" => Ok(parquet::basic::Encoding::RLE), + #[allow(deprecated)] "bit_packed" => Ok(parquet::basic::Encoding::BIT_PACKED), "delta_binary_packed" => Ok(parquet::basic::Encoding::DELTA_BINARY_PACKED), "delta_length_byte_array" => { diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 316624175e1c..a2484e93e812 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1650,7 +1650,11 @@ impl ScalarValue { | DataType::Duration(_) | DataType::Union(_, _) | DataType::Map(_, _) - | DataType::RunEndEncoded(_, _) => { + | DataType::RunEndEncoded(_, _) + | DataType::Utf8View + | DataType::BinaryView + | DataType::ListView(_) + | DataType::LargeListView(_) => { return _internal_err!( "Unsupported creation of {:?} array from ScalarValue {:?}", data_type, @@ -5769,7 +5773,7 @@ mod tests { let batch = RecordBatch::try_from_iter(vec![("s", arr as _)]).unwrap(); #[rustfmt::skip] - let expected = [ + let expected = [ "+---+", "| s |", "+---+", @@ -5803,7 +5807,7 @@ mod tests { &DataType::List(Arc::new(Field::new( "item", DataType::Timestamp(TimeUnit::Millisecond, Some(s.into())), - true + true, ))) ); } diff --git a/datafusion/core/src/datasource/avro_to_arrow/schema.rs b/datafusion/core/src/datasource/avro_to_arrow/schema.rs index 761e6b62680f..039a6aacc07e 100644 --- a/datafusion/core/src/datasource/avro_to_arrow/schema.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/schema.rs @@ -224,6 +224,12 @@ fn default_field_name(dt: &DataType) -> &str { DataType::RunEndEncoded(_, _) => { unimplemented!("RunEndEncoded support not implemented") } + DataType::Utf8View + | DataType::BinaryView + | DataType::ListView(_) + | DataType::LargeListView(_) => { + unimplemented!("View support not implemented") + } DataType::Decimal128(_, _) => "decimal", DataType::Decimal256(_, _) => "decimal", } diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index c04c536e7ca6..b7626d41f4dd 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -78,9 +78,6 @@ use hashbrown::HashMap; use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; -/// Size of the buffer for [`AsyncArrowWriter`]. -const PARQUET_WRITER_BUFFER_SIZE: usize = 10485760; - /// Initial writing buffer size. Note this is just a size hint for efficiency. It /// will grow beyond the set value if needed. const INITIAL_BUFFER_BYTES: usize = 1048576; @@ -626,7 +623,6 @@ impl ParquetSink { let writer = AsyncArrowWriter::try_new( multipart_writer, self.get_writer_schema(), - PARQUET_WRITER_BUFFER_SIZE, Some(parquet_props), )?; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index a2e645cf3e72..282cd624d036 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -701,12 +701,8 @@ pub async fn plan_to_parquet( let (_, multipart_writer) = storeref.put_multipart(&file).await?; let mut stream = plan.execute(i, task_ctx.clone())?; join_set.spawn(async move { - let mut writer = AsyncArrowWriter::try_new( - multipart_writer, - plan.schema(), - 10485760, - propclone, - )?; + let mut writer = + AsyncArrowWriter::try_new(multipart_writer, plan.schema(), propclone)?; while let Some(next_batch) = stream.next().await { let batch = next_batch?; writer.write(&batch).await?; diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index 1f00f5bc3137..5d2719bf0365 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -18,16 +18,14 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::types::ArrowTemporalType; -use arrow::array::{Array, ArrayRef, ArrowNumericType, Float64Array, PrimitiveArray}; -use arrow::compute::cast; -use arrow::compute::kernels::temporal; +use arrow::array::{Array, ArrayRef, Float64Array}; +use arrow::compute::{binary, cast, date_part, DatePart}; use arrow::datatypes::DataType::{Date32, Date64, Float64, Timestamp, Utf8}; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::cast::{ - as_date32_array, as_date64_array, as_timestamp_microsecond_array, + as_date32_array, as_date64_array, as_int32_array, as_timestamp_microsecond_array, as_timestamp_millisecond_array, as_timestamp_nanosecond_array, as_timestamp_second_array, }; @@ -78,46 +76,6 @@ impl DatePartFunc { } } -macro_rules! extract_date_part { - ($ARRAY: expr, $FN:expr) => { - match $ARRAY.data_type() { - DataType::Date32 => { - let array = as_date32_array($ARRAY)?; - Ok($FN(array) - .map(|v| cast(&(Arc::new(v) as ArrayRef), &DataType::Float64))?) - } - DataType::Date64 => { - let array = as_date64_array($ARRAY)?; - Ok($FN(array) - .map(|v| cast(&(Arc::new(v) as ArrayRef), &DataType::Float64))?) - } - DataType::Timestamp(time_unit, _) => match time_unit { - TimeUnit::Second => { - let array = as_timestamp_second_array($ARRAY)?; - Ok($FN(array) - .map(|v| cast(&(Arc::new(v) as ArrayRef), &DataType::Float64))?) - } - TimeUnit::Millisecond => { - let array = as_timestamp_millisecond_array($ARRAY)?; - Ok($FN(array) - .map(|v| cast(&(Arc::new(v) as ArrayRef), &DataType::Float64))?) - } - TimeUnit::Microsecond => { - let array = as_timestamp_microsecond_array($ARRAY)?; - Ok($FN(array) - .map(|v| cast(&(Arc::new(v) as ArrayRef), &DataType::Float64))?) - } - TimeUnit::Nanosecond => { - let array = as_timestamp_nanosecond_array($ARRAY)?; - Ok($FN(array) - .map(|v| cast(&(Arc::new(v) as ArrayRef), &DataType::Float64))?) - } - }, - datatype => exec_err!("Extract does not support datatype {:?}", datatype), - } - }; -} - impl ScalarUDFImpl for DatePartFunc { fn as_any(&self) -> &dyn Any { self @@ -139,16 +97,15 @@ impl ScalarUDFImpl for DatePartFunc { if args.len() != 2 { return exec_err!("Expected two arguments in DATE_PART"); } - let (date_part, array) = (&args[0], &args[1]); + let (part, array) = (&args[0], &args[1]); - let date_part = - if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = date_part { - v - } else { - return exec_err!( - "First argument of `DATE_PART` must be non-null scalar Utf8" - ); - }; + let part = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = part { + v + } else { + return exec_err!( + "First argument of `DATE_PART` must be non-null scalar Utf8" + ); + }; let is_scalar = matches!(array, ColumnarValue::Scalar(_)); @@ -157,28 +114,28 @@ impl ScalarUDFImpl for DatePartFunc { ColumnarValue::Scalar(scalar) => scalar.to_array()?, }; - let arr = match date_part.to_lowercase().as_str() { - "year" => extract_date_part!(&array, temporal::year), - "quarter" => extract_date_part!(&array, temporal::quarter), - "month" => extract_date_part!(&array, temporal::month), - "week" => extract_date_part!(&array, temporal::week), - "day" => extract_date_part!(&array, temporal::day), - "doy" => extract_date_part!(&array, temporal::doy), - "dow" => extract_date_part!(&array, temporal::num_days_from_sunday), - "hour" => extract_date_part!(&array, temporal::hour), - "minute" => extract_date_part!(&array, temporal::minute), - "second" => extract_date_part!(&array, seconds), - "millisecond" => extract_date_part!(&array, millis), - "microsecond" => extract_date_part!(&array, micros), - "nanosecond" => extract_date_part!(&array, nanos), - "epoch" => extract_date_part!(&array, epoch), - _ => exec_err!("Date part '{date_part}' not supported"), - }?; + let arr = match part.to_lowercase().as_str() { + "year" => date_part_f64(array.as_ref(), DatePart::Year)?, + "quarter" => date_part_f64(array.as_ref(), DatePart::Quarter)?, + "month" => date_part_f64(array.as_ref(), DatePart::Month)?, + "week" => date_part_f64(array.as_ref(), DatePart::Week)?, + "day" => date_part_f64(array.as_ref(), DatePart::Day)?, + "doy" => date_part_f64(array.as_ref(), DatePart::DayOfYear)?, + "dow" => date_part_f64(array.as_ref(), DatePart::DayOfWeekSunday0)?, + "hour" => date_part_f64(array.as_ref(), DatePart::Hour)?, + "minute" => date_part_f64(array.as_ref(), DatePart::Minute)?, + "second" => seconds(array.as_ref(), Second)?, + "millisecond" => seconds(array.as_ref(), Millisecond)?, + "microsecond" => seconds(array.as_ref(), Microsecond)?, + "nanosecond" => seconds(array.as_ref(), Nanosecond)?, + "epoch" => epoch(array.as_ref())?, + _ => return exec_err!("Date part '{part}' not supported"), + }; Ok(if is_scalar { - ColumnarValue::Scalar(ScalarValue::try_from_array(&arr?, 0)?) + ColumnarValue::Scalar(ScalarValue::try_from_array(arr.as_ref(), 0)?) } else { - ColumnarValue::Array(arr?) + ColumnarValue::Array(arr) }) } @@ -187,83 +144,52 @@ impl ScalarUDFImpl for DatePartFunc { } } -fn to_ticks(array: &PrimitiveArray, frac: i32) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - let zipped = temporal::second(array)? - .values() - .iter() - .zip(temporal::nanosecond(array)?.values().iter()) - .map(|o| (*o.0 as f64 + (*o.1 as f64) / 1_000_000_000.0) * (frac as f64)) - .collect::>(); - - Ok(Float64Array::from(zipped)) +/// Invoke [`date_part`] and cast the result to Float64 +fn date_part_f64(array: &dyn Array, part: DatePart) -> Result { + Ok(cast(date_part(array, part)?.as_ref(), &Float64)?) } -fn seconds(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - to_ticks(array, 1) -} - -fn millis(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - to_ticks(array, 1_000) -} - -fn micros(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - to_ticks(array, 1_000_000) +/// invoke [`date_part`] on an `array` (e.g. Timestamp) and convert the +/// result to a total number of seconds, milliseconds, microseconds or +/// nanoseconds +/// +/// # Panics +/// If `array` is not a temporal type such as Timestamp or Date32 +fn seconds(array: &dyn Array, unit: TimeUnit) -> Result { + let sf = match unit { + Second => 1_f64, + Millisecond => 1_000_f64, + Microsecond => 1_000_000_f64, + Nanosecond => 1_000_000_000_f64, + }; + let secs = date_part(array, DatePart::Second)?; + let secs = as_int32_array(secs.as_ref())?; + let subsecs = date_part(array, DatePart::Nanosecond)?; + let subsecs = as_int32_array(subsecs.as_ref())?; + + let r: Float64Array = binary(secs, subsecs, |secs, subsecs| { + (secs as f64 + (subsecs as f64 / 1_000_000_000_f64)) * sf + })?; + Ok(Arc::new(r)) } -fn nanos(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - to_ticks(array, 1_000_000_000) -} +fn epoch(array: &dyn Array) -> Result { + const SECONDS_IN_A_DAY: f64 = 86400_f64; -fn epoch(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - let b = match array.data_type() { - Timestamp(tu, _) => { - let scale = match tu { - Second => 1, - Millisecond => 1_000, - Microsecond => 1_000_000, - Nanosecond => 1_000_000_000, - } as f64; - array.unary(|n| { - let n: i64 = n.into(); - n as f64 / scale - }) + let f: Float64Array = match array.data_type() { + Timestamp(Second, _) => as_timestamp_second_array(array)?.unary(|x| x as f64), + Timestamp(Millisecond, _) => { + as_timestamp_millisecond_array(array)?.unary(|x| x as f64 / 1_000_f64) + } + Timestamp(Microsecond, _) => { + as_timestamp_microsecond_array(array)?.unary(|x| x as f64 / 1_000_000_f64) } - Date32 => { - let seconds_in_a_day = 86400_f64; - array.unary(|n| { - let n: i64 = n.into(); - n as f64 * seconds_in_a_day - }) + Timestamp(Nanosecond, _) => { + as_timestamp_nanosecond_array(array)?.unary(|x| x as f64 / 1_000_000_000_f64) } - Date64 => array.unary(|n| { - let n: i64 = n.into(); - n as f64 / 1_000_f64 - }), - _ => return exec_err!("Can not convert {:?} to epoch", array.data_type()), + Date32 => as_date32_array(array)?.unary(|x| x as f64 * SECONDS_IN_A_DAY), + Date64 => as_date64_array(array)?.unary(|x| x as f64 / 1_000_f64), + d => return exec_err!("Can not convert {d:?} to epoch"), }; - Ok(b) + Ok(Arc::new(f)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 478f7c779552..92015594906b 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -230,6 +230,9 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum { "Proto serialization error: The RunEndEncoded data type is not yet supported".to_owned() )) } + DataType::Utf8View | DataType::BinaryView | DataType::ListView(_) | DataType::LargeListView(_) => { + return Err(Error::General(format!("Proto serialization error: {val} not yet supported"))) + } }; Ok(res) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 9680177d736f..c26e8481ce43 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -491,11 +491,15 @@ impl Unparser<'_> { DataType::Binary => todo!(), DataType::FixedSizeBinary(_) => todo!(), DataType::LargeBinary => todo!(), + DataType::BinaryView => todo!(), DataType::Utf8 => Ok(ast::DataType::Varchar(None)), DataType::LargeUtf8 => Ok(ast::DataType::Text), + DataType::Utf8View => todo!(), DataType::List(_) => todo!(), DataType::FixedSizeList(_, _) => todo!(), DataType::LargeList(_) => todo!(), + DataType::ListView(_) => todo!(), + DataType::LargeListView(_) => todo!(), DataType::Struct(_) => todo!(), DataType::Union(_, _) => todo!(), DataType::Dictionary(_, _) => todo!(), From 7fab5ac53c1e715743aee7a51111c2976add8a99 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Wed, 20 Mar 2024 00:58:10 +0800 Subject: [PATCH 17/35] Move inlist rule to expr_simplifier (#9692) * move inlist rule to expr_simplifier Signed-off-by: jayzhan211 * clippy Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- .../simplify_expressions/expr_simplifier.rs | 220 +++++++++++++++++- .../simplify_expressions/inlist_simplifier.rs | 122 +--------- .../sqllogictest/test_files/predicates.slt | 2 +- 3 files changed, 210 insertions(+), 134 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 5b5bca75ddb0..61e002ece98b 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -21,7 +21,7 @@ use std::borrow::Cow; use std::collections::HashSet; use std::ops::Not; -use super::inlist_simplifier::{InListSimplifier, ShortenInListSimplifier}; +use super::inlist_simplifier::ShortenInListSimplifier; use super::utils::*; use crate::analyzer::type_coercion::TypeCoercionRewriter; use crate::simplify_expressions::guarantees::GuaranteeRewriter; @@ -175,7 +175,6 @@ impl ExprSimplifier { let mut simplifier = Simplifier::new(&self.info); let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; let mut shorten_in_list_simplifier = ShortenInListSimplifier::new(); - let mut inlist_simplifier = InListSimplifier::new(); let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees); if self.canonicalize { @@ -190,8 +189,6 @@ impl ExprSimplifier { .data()? .rewrite(&mut simplifier) .data()? - .rewrite(&mut inlist_simplifier) - .data()? .rewrite(&mut guarantee_rewriter) .data()? // run both passes twice to try an minimize simplifications that we missed @@ -1452,13 +1449,8 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Operator::Or, right, }) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => { - let left = as_inlist(left.as_ref()); - let right = as_inlist(right.as_ref()); - - let lhs = left.unwrap(); - let rhs = right.unwrap(); - let lhs = lhs.into_owned(); - let rhs = rhs.into_owned(); + let lhs = to_inlist(*left).unwrap(); + let rhs = to_inlist(*right).unwrap(); let mut seen: HashSet = HashSet::new(); let list = lhs .list @@ -1473,7 +1465,123 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { negated: false, }; - return Ok(Transformed::yes(Expr::InList(merged_inlist))); + Transformed::yes(Expr::InList(merged_inlist)) + } + + // Simplify expressions that is guaranteed to be true or false to a literal boolean expression + // + // Rules: + // If both expressions are `IN` or `NOT IN`, then we can apply intersection or union on both lists + // Intersection: + // 1. `a in (1,2,3) AND a in (4,5) -> a in (), which is false` + // 2. `a in (1,2,3) AND a in (2,3,4) -> a in (2,3)` + // 3. `a not in (1,2,3) OR a not in (3,4,5,6) -> a not in (3)` + // Union: + // 4. `a not int (1,2,3) AND a not in (4,5,6) -> a not in (1,2,3,4,5,6)` + // # This rule is handled by `or_in_list_simplifier.rs` + // 5. `a in (1,2,3) OR a in (4,5,6) -> a in (1,2,3,4,5,6)` + // If one of the expressions is `IN` and another one is `NOT IN`, then we apply exception on `In` expression + // 6. `a in (1,2,3,4) AND a not in (1,2,3,4,5) -> a in (), which is false` + // 7. `a not in (1,2,3,4) AND a in (1,2,3,4,5) -> a = 5` + // 8. `a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)` + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::And, + right, + }) if are_inlist_and_eq_and_match_neg( + left.as_ref(), + right.as_ref(), + false, + false, + ) => + { + match (*left, *right) { + (Expr::InList(l1), Expr::InList(l2)) => { + return inlist_intersection(l1, l2, false).map(Transformed::yes); + } + // Matched previously once + _ => unreachable!(), + } + } + + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::And, + right, + }) if are_inlist_and_eq_and_match_neg( + left.as_ref(), + right.as_ref(), + true, + true, + ) => + { + match (*left, *right) { + (Expr::InList(l1), Expr::InList(l2)) => { + return inlist_union(l1, l2, true).map(Transformed::yes); + } + // Matched previously once + _ => unreachable!(), + } + } + + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::And, + right, + }) if are_inlist_and_eq_and_match_neg( + left.as_ref(), + right.as_ref(), + false, + true, + ) => + { + match (*left, *right) { + (Expr::InList(l1), Expr::InList(l2)) => { + return inlist_except(l1, l2).map(Transformed::yes); + } + // Matched previously once + _ => unreachable!(), + } + } + + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::And, + right, + }) if are_inlist_and_eq_and_match_neg( + left.as_ref(), + right.as_ref(), + true, + false, + ) => + { + match (*left, *right) { + (Expr::InList(l1), Expr::InList(l2)) => { + return inlist_except(l2, l1).map(Transformed::yes); + } + // Matched previously once + _ => unreachable!(), + } + } + + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Or, + right, + }) if are_inlist_and_eq_and_match_neg( + left.as_ref(), + right.as_ref(), + true, + true, + ) => + { + match (*left, *right) { + (Expr::InList(l1), Expr::InList(l2)) => { + return inlist_intersection(l1, l2, true).map(Transformed::yes); + } + // Matched previously once + _ => unreachable!(), + } } // no additional rewrites possible @@ -1482,6 +1590,22 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } } +// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121 +fn are_inlist_and_eq_and_match_neg( + left: &Expr, + right: &Expr, + is_left_neg: bool, + is_right_neg: bool, +) -> bool { + match (left, right) { + (Expr::InList(l), Expr::InList(r)) => { + l.expr == r.expr && l.negated == is_left_neg && r.negated == is_right_neg + } + _ => false, + } +} + +// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121 fn are_inlist_and_eq(left: &Expr, right: &Expr) -> bool { let left = as_inlist(left); let right = as_inlist(right); @@ -1519,6 +1643,78 @@ fn as_inlist(expr: &Expr) -> Option> { } } +fn to_inlist(expr: Expr) -> Option { + match expr { + Expr::InList(inlist) => Some(inlist), + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + }) => match (left.as_ref(), right.as_ref()) { + (Expr::Column(_), Expr::Literal(_)) => Some(InList { + expr: left, + list: vec![*right], + negated: false, + }), + (Expr::Literal(_), Expr::Column(_)) => Some(InList { + expr: right, + list: vec![*left], + negated: false, + }), + _ => None, + }, + _ => None, + } +} + +/// Return the union of two inlist expressions +/// maintaining the order of the elements in the two lists +fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result { + // extend the list in l1 with the elements in l2 that are not already in l1 + let l1_items: HashSet<_> = l1.list.iter().collect(); + + // keep all l2 items that do not also appear in l1 + let keep_l2: Vec<_> = l2 + .list + .into_iter() + .filter_map(|e| if l1_items.contains(&e) { None } else { Some(e) }) + .collect(); + + l1.list.extend(keep_l2); + l1.negated = negated; + Ok(Expr::InList(l1)) +} + +/// Return the intersection of two inlist expressions +/// maintaining the order of the elements in the two lists +fn inlist_intersection(mut l1: InList, l2: InList, negated: bool) -> Result { + let l2_items = l2.list.iter().collect::>(); + + // remove all items from l1 that are not in l2 + l1.list.retain(|e| l2_items.contains(e)); + + // e in () is always false + // e not in () is always true + if l1.list.is_empty() { + return Ok(lit(negated)); + } + Ok(Expr::InList(l1)) +} + +/// Return the all items in l1 that are not in l2 +/// maintaining the order of the elements in the two lists +fn inlist_except(mut l1: InList, l2: InList) -> Result { + let l2_items = l2.list.iter().collect::>(); + + // keep only items from l1 that are not in l2 + l1.list.retain(|e| !l2_items.contains(e)); + + if l1.list.is_empty() { + return Ok(lit(false)); + } + Ok(Expr::InList(l1)) +} + #[cfg(test)] mod tests { use std::{ diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs index 5d1cf27827a9..9dcb8ed15563 100644 --- a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs @@ -19,12 +19,10 @@ use super::THRESHOLD_INLINE_INLIST; -use std::collections::HashSet; - use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::Result; use datafusion_expr::expr::InList; -use datafusion_expr::{lit, BinaryExpr, Expr, Operator}; +use datafusion_expr::Expr; pub(super) struct ShortenInListSimplifier {} @@ -97,121 +95,3 @@ impl TreeNodeRewriter for ShortenInListSimplifier { Ok(Transformed::no(expr)) } } - -pub(super) struct InListSimplifier {} - -impl InListSimplifier { - pub(super) fn new() -> Self { - Self {} - } -} - -impl TreeNodeRewriter for InListSimplifier { - type Node = Expr; - - fn f_up(&mut self, expr: Expr) -> Result> { - // Simplify expressions that is guaranteed to be true or false to a literal boolean expression - // - // Rules: - // If both expressions are `IN` or `NOT IN`, then we can apply intersection or union on both lists - // Intersection: - // 1. `a in (1,2,3) AND a in (4,5) -> a in (), which is false` - // 2. `a in (1,2,3) AND a in (2,3,4) -> a in (2,3)` - // 3. `a not in (1,2,3) OR a not in (3,4,5,6) -> a not in (3)` - // Union: - // 4. `a not int (1,2,3) AND a not in (4,5,6) -> a not in (1,2,3,4,5,6)` - // # This rule is handled by `or_in_list_simplifier.rs` - // 5. `a in (1,2,3) OR a in (4,5,6) -> a in (1,2,3,4,5,6)` - // If one of the expressions is `IN` and another one is `NOT IN`, then we apply exception on `In` expression - // 6. `a in (1,2,3,4) AND a not in (1,2,3,4,5) -> a in (), which is false` - // 7. `a not in (1,2,3,4) AND a in (1,2,3,4,5) -> a = 5` - // 8. `a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)` - if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr.clone() { - match (*left, op, *right) { - (Expr::InList(l1), Operator::And, Expr::InList(l2)) - if l1.expr == l2.expr && !l1.negated && !l2.negated => - { - return inlist_intersection(l1, l2, false).map(Transformed::yes); - } - (Expr::InList(l1), Operator::And, Expr::InList(l2)) - if l1.expr == l2.expr && l1.negated && l2.negated => - { - return inlist_union(l1, l2, true).map(Transformed::yes); - } - (Expr::InList(l1), Operator::And, Expr::InList(l2)) - if l1.expr == l2.expr && !l1.negated && l2.negated => - { - return inlist_except(l1, l2).map(Transformed::yes); - } - (Expr::InList(l1), Operator::And, Expr::InList(l2)) - if l1.expr == l2.expr && l1.negated && !l2.negated => - { - return inlist_except(l2, l1).map(Transformed::yes); - } - (Expr::InList(l1), Operator::Or, Expr::InList(l2)) - if l1.expr == l2.expr && l1.negated && l2.negated => - { - return inlist_intersection(l1, l2, true).map(Transformed::yes); - } - (left, op, right) => { - // put the expression back together - return Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { - left: Box::new(left), - op, - right: Box::new(right), - }))); - } - } - } - - Ok(Transformed::no(expr)) - } -} - -/// Return the union of two inlist expressions -/// maintaining the order of the elements in the two lists -fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result { - // extend the list in l1 with the elements in l2 that are not already in l1 - let l1_items: HashSet<_> = l1.list.iter().collect(); - - // keep all l2 items that do not also appear in l1 - let keep_l2: Vec<_> = l2 - .list - .into_iter() - .filter_map(|e| if l1_items.contains(&e) { None } else { Some(e) }) - .collect(); - - l1.list.extend(keep_l2); - l1.negated = negated; - Ok(Expr::InList(l1)) -} - -/// Return the intersection of two inlist expressions -/// maintaining the order of the elements in the two lists -fn inlist_intersection(mut l1: InList, l2: InList, negated: bool) -> Result { - let l2_items = l2.list.iter().collect::>(); - - // remove all items from l1 that are not in l2 - l1.list.retain(|e| l2_items.contains(e)); - - // e in () is always false - // e not in () is always true - if l1.list.is_empty() { - return Ok(lit(negated)); - } - Ok(Expr::InList(l1)) -} - -/// Return the all items in l1 that are not in l2 -/// maintaining the order of the elements in the two lists -fn inlist_except(mut l1: InList, l2: InList) -> Result { - let l2_items = l2.list.iter().collect::>(); - - // keep only items from l1 that are not in l2 - l1.list.retain(|e| !l2_items.contains(e)); - - if l1.list.is_empty() { - return Ok(lit(false)); - } - Ok(Expr::InList(l1)) -} diff --git a/datafusion/sqllogictest/test_files/predicates.slt b/datafusion/sqllogictest/test_files/predicates.slt index 4c9254beef6b..33c9ff7c3eed 100644 --- a/datafusion/sqllogictest/test_files/predicates.slt +++ b/datafusion/sqllogictest/test_files/predicates.slt @@ -781,4 +781,4 @@ logical_plan EmptyRelation physical_plan EmptyExec statement ok -drop table t; +drop table t; \ No newline at end of file From 09747596fd75bfce8903e86472cccb8acc524453 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Tue, 19 Mar 2024 11:49:36 -0600 Subject: [PATCH 18/35] Support Serde for ScalarUDF in Physical Expressions (#9436) * initial try * revert * stage commit * use ScalarFunctionDefinition to rewrite PhysicalExpr proto * cargo fmt * feat : add test * fix bug * fix wrong delete code when resolve conflict * Update datafusion/proto/src/physical_plan/to_proto.rs Co-authored-by: Dan Harris <1327726+thinkharderdev@users.noreply.github.com> * Update datafusion/proto/tests/cases/roundtrip_physical_plan.rs Co-authored-by: Dan Harris <1327726+thinkharderdev@users.noreply.github.com> * address the comment --------- Co-authored-by: Dan Harris <1327726+thinkharderdev@users.noreply.github.com> --- .../physical_optimizer/projection_pushdown.rs | 58 ++- datafusion/physical-expr/src/functions.rs | 10 +- .../physical-expr/src/scalar_function.rs | 26 +- datafusion/physical-expr/src/udf.rs | 7 +- datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 21 + datafusion/proto/src/generated/prost.rs | 2 + .../proto/src/physical_plan/from_proto.rs | 44 +- datafusion/proto/src/physical_plan/mod.rs | 139 ++++-- .../proto/src/physical_plan/to_proto.rs | 433 ++++++++++-------- .../tests/cases/roundtrip_physical_plan.rs | 157 ++++++- 11 files changed, 634 insertions(+), 264 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index ab5611597472..ed445e6d48b8 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -1287,6 +1287,7 @@ fn new_join_children( #[cfg(test)] mod tests { use super::*; + use std::any::Any; use std::sync::Arc; use crate::datasource::file_format::file_compression_type::FileCompressionType; @@ -1313,7 +1314,10 @@ mod tests { use datafusion_common::{JoinSide, JoinType, Result, ScalarValue, Statistics}; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; - use datafusion_expr::{ColumnarValue, Operator}; + use datafusion_expr::{ + ColumnarValue, Operator, ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, + }; use datafusion_physical_expr::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, }; @@ -1329,6 +1333,42 @@ mod tests { use itertools::Itertools; + /// Mocked UDF + #[derive(Debug)] + struct DummyUDF { + signature: Signature, + } + + impl DummyUDF { + fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + } + } + } + + impl ScalarUDFImpl for DummyUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "dummy_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!("DummyUDF::invoke") + } + } + #[test] fn test_update_matching_exprs() -> Result<()> { let exprs: Vec> = vec![ @@ -1345,7 +1385,9 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + DummyUDF::new(), + ))), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -1412,7 +1454,9 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 5)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + DummyUDF::new(), + ))), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -1482,7 +1526,9 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + DummyUDF::new(), + ))), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -1549,7 +1595,9 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f_new", 5)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + DummyUDF::new(), + ))), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b_new", 1)), diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index c6c185e002f0..e76e7f56dc95 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -44,6 +44,7 @@ use arrow_array::Array; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::execution_props::ExecutionProps; pub use datafusion_expr::FuncMonotonicity; +use datafusion_expr::ScalarFunctionDefinition; use datafusion_expr::{ type_coercion::functions::data_types, BuiltinScalarFunction, ColumnarValue, ScalarFunctionImplementation, @@ -57,7 +58,7 @@ pub fn create_physical_expr( fun: &BuiltinScalarFunction, input_phy_exprs: &[Arc], input_schema: &Schema, - execution_props: &ExecutionProps, + _execution_props: &ExecutionProps, ) -> Result> { let input_expr_types = input_phy_exprs .iter() @@ -69,14 +70,12 @@ pub fn create_physical_expr( let data_type = fun.return_type(&input_expr_types)?; - let fun_expr: ScalarFunctionImplementation = - create_physical_fun(fun, execution_props)?; - let monotonicity = fun.monotonicity(); + let fun_def = ScalarFunctionDefinition::BuiltIn(*fun); Ok(Arc::new(ScalarFunctionExpr::new( &format!("{fun}"), - fun_expr, + fun_def, input_phy_exprs.to_vec(), data_type, monotonicity, @@ -195,7 +194,6 @@ where /// Create a physical scalar function. pub fn create_physical_fun( fun: &BuiltinScalarFunction, - _execution_props: &ExecutionProps, ) -> Result { Ok(match fun { // math functions diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 1c9f0e609c3c..d34084236690 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -34,22 +34,22 @@ use std::fmt::{self, Debug, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::functions::out_ordering; +use crate::functions::{create_physical_fun, out_ordering}; use crate::physical_expr::{down_cast_any_ref, physical_exprs_equal}; use crate::sort_properties::SortProperties; use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::Result; +use datafusion_common::{internal_err, Result}; use datafusion_expr::{ expr_vec_fmt, BuiltinScalarFunction, ColumnarValue, FuncMonotonicity, - ScalarFunctionImplementation, + ScalarFunctionDefinition, }; /// Physical expression of a scalar function pub struct ScalarFunctionExpr { - fun: ScalarFunctionImplementation, + fun: ScalarFunctionDefinition, name: String, args: Vec>, return_type: DataType, @@ -79,7 +79,7 @@ impl ScalarFunctionExpr { /// Create a new Scalar function pub fn new( name: &str, - fun: ScalarFunctionImplementation, + fun: ScalarFunctionDefinition, args: Vec>, return_type: DataType, monotonicity: Option, @@ -96,7 +96,7 @@ impl ScalarFunctionExpr { } /// Get the scalar function implementation - pub fn fun(&self) -> &ScalarFunctionImplementation { + pub fn fun(&self) -> &ScalarFunctionDefinition { &self.fun } @@ -172,8 +172,18 @@ impl PhysicalExpr for ScalarFunctionExpr { }; // evaluate the function - let fun = self.fun.as_ref(); - (fun)(&inputs) + match self.fun { + ScalarFunctionDefinition::BuiltIn(ref fun) => { + let fun = create_physical_fun(fun)?; + (fun)(&inputs) + } + ScalarFunctionDefinition::UDF(ref fun) => fun.invoke(&inputs), + ScalarFunctionDefinition::Name(_) => { + internal_err!( + "Name function must be resolved to one of the other variants prior to physical planning" + ) + } + } } fn children(&self) -> Vec> { diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index ede3e5badbb1..4fc94bfa15ec 100644 --- a/datafusion/physical-expr/src/udf.rs +++ b/datafusion/physical-expr/src/udf.rs @@ -20,7 +20,9 @@ use crate::{PhysicalExpr, ScalarFunctionExpr}; use arrow_schema::Schema; use datafusion_common::{DFSchema, Result}; pub use datafusion_expr::ScalarUDF; -use datafusion_expr::{type_coercion::functions::data_types, Expr}; +use datafusion_expr::{ + type_coercion::functions::data_types, Expr, ScalarFunctionDefinition, +}; use std::sync::Arc; /// Create a physical expression of the UDF. @@ -45,9 +47,10 @@ pub fn create_physical_expr( let return_type = fun.return_type_from_exprs(args, input_dfschema, &input_expr_types)?; + let fun_def = ScalarFunctionDefinition::UDF(Arc::new(fun.clone())); Ok(Arc::new(ScalarFunctionExpr::new( fun.name(), - fun.fun(), + fun_def, input_phy_exprs.to_vec(), return_type, fun.monotonicity()?, diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 6879f70cd05c..7a9b427ce7d3 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1458,6 +1458,7 @@ message PhysicalExprNode { message PhysicalScalarUdfNode { string name = 1; repeated PhysicalExprNode args = 2; + optional bytes fun_definition = 3; ArrowType return_type = 4; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 75c135fd01b4..fd27520b3be0 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20391,6 +20391,9 @@ impl serde::Serialize for PhysicalScalarUdfNode { if !self.args.is_empty() { len += 1; } + if self.fun_definition.is_some() { + len += 1; + } if self.return_type.is_some() { len += 1; } @@ -20401,6 +20404,10 @@ impl serde::Serialize for PhysicalScalarUdfNode { if !self.args.is_empty() { struct_ser.serialize_field("args", &self.args)?; } + if let Some(v) = self.fun_definition.as_ref() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; + } if let Some(v) = self.return_type.as_ref() { struct_ser.serialize_field("returnType", v)?; } @@ -20416,6 +20423,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { const FIELDS: &[&str] = &[ "name", "args", + "fun_definition", + "funDefinition", "return_type", "returnType", ]; @@ -20424,6 +20433,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { enum GeneratedField { Name, Args, + FunDefinition, ReturnType, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -20448,6 +20458,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { match value { "name" => Ok(GeneratedField::Name), "args" => Ok(GeneratedField::Args), + "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), "returnType" | "return_type" => Ok(GeneratedField::ReturnType), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -20470,6 +20481,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { { let mut name__ = None; let mut args__ = None; + let mut fun_definition__ = None; let mut return_type__ = None; while let Some(k) = map_.next_key()? { match k { @@ -20485,6 +20497,14 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { } args__ = Some(map_.next_value()?); } + GeneratedField::FunDefinition => { + if fun_definition__.is_some() { + return Err(serde::de::Error::duplicate_field("funDefinition")); + } + fun_definition__ = + map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) + ; + } GeneratedField::ReturnType => { if return_type__.is_some() { return Err(serde::de::Error::duplicate_field("returnType")); @@ -20496,6 +20516,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { Ok(PhysicalScalarUdfNode { name: name__.unwrap_or_default(), args: args__.unwrap_or_default(), + fun_definition: fun_definition__, return_type: return_type__, }) } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index c9cc4a9b073b..16ad2b848db9 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2092,6 +2092,8 @@ pub struct PhysicalScalarUdfNode { pub name: ::prost::alloc::string::String, #[prost(message, repeated, tag = "2")] pub args: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", optional, tag = "3")] + pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, #[prost(message, optional, tag = "4")] pub return_type: ::core::option::Option, } diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 184c048c1bdd..ca54d4e803ca 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -59,9 +59,12 @@ use datafusion_common::stats::Precision; use datafusion_common::{not_impl_err, DataFusionError, JoinSide, Result, ScalarValue}; use chrono::{TimeZone, Utc}; +use datafusion_expr::ScalarFunctionDefinition; use object_store::path::Path; use object_store::ObjectMeta; +use super::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec}; + impl From<&protobuf::PhysicalColumn> for Column { fn from(c: &protobuf::PhysicalColumn) -> Column { Column::new(&c.name, c.index as usize) @@ -82,7 +85,8 @@ pub fn parse_physical_sort_expr( input_schema: &Schema, ) -> Result { if let Some(expr) = &proto.expr { - let expr = parse_physical_expr(expr.as_ref(), registry, input_schema)?; + let codec = DefaultPhysicalExtensionCodec {}; + let expr = parse_physical_expr(expr.as_ref(), registry, input_schema, &codec)?; let options = SortOptions { descending: !proto.asc, nulls_first: proto.nulls_first, @@ -110,7 +114,9 @@ pub fn parse_physical_sort_exprs( .iter() .map(|sort_expr| { if let Some(expr) = &sort_expr.expr { - let expr = parse_physical_expr(expr.as_ref(), registry, input_schema)?; + let codec = DefaultPhysicalExtensionCodec {}; + let expr = + parse_physical_expr(expr.as_ref(), registry, input_schema, &codec)?; let options = SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -137,16 +143,17 @@ pub fn parse_physical_window_expr( registry: &dyn FunctionRegistry, input_schema: &Schema, ) -> Result> { + let codec = DefaultPhysicalExtensionCodec {}; let window_node_expr = proto .args .iter() - .map(|e| parse_physical_expr(e, registry, input_schema)) + .map(|e| parse_physical_expr(e, registry, input_schema, &codec)) .collect::>>()?; let partition_by = proto .partition_by .iter() - .map(|p| parse_physical_expr(p, registry, input_schema)) + .map(|p| parse_physical_expr(p, registry, input_schema, &codec)) .collect::>>()?; let order_by = proto @@ -191,6 +198,7 @@ pub fn parse_physical_expr( proto: &protobuf::PhysicalExprNode, registry: &dyn FunctionRegistry, input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, ) -> Result> { let expr_type = proto .expr_type @@ -270,7 +278,7 @@ pub fn parse_physical_expr( )?, e.list .iter() - .map(|x| parse_physical_expr(x, registry, input_schema)) + .map(|x| parse_physical_expr(x, registry, input_schema, codec)) .collect::, _>>()?, &e.negated, input_schema, @@ -278,7 +286,7 @@ pub fn parse_physical_expr( ExprType::Case(e) => Arc::new(CaseExpr::try_new( e.expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), registry, input_schema)) + .map(|e| parse_physical_expr(e.as_ref(), registry, input_schema, codec)) .transpose()?, e.when_then_expr .iter() @@ -301,7 +309,7 @@ pub fn parse_physical_expr( .collect::>>()?, e.else_expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), registry, input_schema)) + .map(|e| parse_physical_expr(e.as_ref(), registry, input_schema, codec)) .transpose()?, )?), ExprType::Cast(e) => Arc::new(CastExpr::new( @@ -334,7 +342,7 @@ pub fn parse_physical_expr( let args = e .args .iter() - .map(|x| parse_physical_expr(x, registry, input_schema)) + .map(|x| parse_physical_expr(x, registry, input_schema, codec)) .collect::, _>>()?; // TODO Do not create new the ExecutionProps @@ -348,19 +356,22 @@ pub fn parse_physical_expr( )? } ExprType::ScalarUdf(e) => { - let udf = registry.udf(e.name.as_str())?; + let udf = match &e.fun_definition { + Some(buf) => codec.try_decode_udf(&e.name, buf)?, + None => registry.udf(e.name.as_str())?, + }; let signature = udf.signature(); - let scalar_fun = udf.fun().clone(); + let scalar_fun_def = ScalarFunctionDefinition::UDF(udf.clone()); let args = e .args .iter() - .map(|x| parse_physical_expr(x, registry, input_schema)) + .map(|x| parse_physical_expr(x, registry, input_schema, codec)) .collect::, _>>()?; Arc::new(ScalarFunctionExpr::new( e.name.as_str(), - scalar_fun, + scalar_fun_def, args, convert_required!(e.return_type)?, None, @@ -394,7 +405,8 @@ fn parse_required_physical_expr( field: &str, input_schema: &Schema, ) -> Result> { - expr.map(|e| parse_physical_expr(e, registry, input_schema)) + let codec = DefaultPhysicalExtensionCodec {}; + expr.map(|e| parse_physical_expr(e, registry, input_schema, &codec)) .transpose()? .ok_or_else(|| { DataFusionError::Internal(format!("Missing required field {field:?}")) @@ -439,10 +451,11 @@ pub fn parse_protobuf_hash_partitioning( ) -> Result> { match partitioning { Some(hash_part) => { + let codec = DefaultPhysicalExtensionCodec {}; let expr = hash_part .hash_expr .iter() - .map(|e| parse_physical_expr(e, registry, input_schema)) + .map(|e| parse_physical_expr(e, registry, input_schema, &codec)) .collect::>, _>>()?; Ok(Some(Partitioning::Hash( @@ -503,6 +516,7 @@ pub fn parse_protobuf_file_scan_config( let mut output_ordering = vec![]; for node_collection in &proto.output_ordering { + let codec = DefaultPhysicalExtensionCodec {}; let sort_expr = node_collection .physical_sort_expr_nodes .iter() @@ -510,7 +524,7 @@ pub fn parse_protobuf_file_scan_config( let expr = node .expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), registry, &schema)) + .map(|e| parse_physical_expr(e.as_ref(), registry, &schema, &codec)) .unwrap()?; Ok(PhysicalSortExpr { expr, diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 004948da938f..da31c5e762bc 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -20,6 +20,7 @@ use std::fmt::Debug; use std::sync::Arc; use self::from_proto::parse_physical_window_expr; +use self::to_proto::serialize_physical_expr; use crate::common::{byte_to_string, proto_error, str_to_byte}; use crate::convert_required; @@ -138,7 +139,12 @@ impl AsExecutionPlan for PhysicalPlanNode { .zip(projection.expr_name.iter()) .map(|(expr, name)| { Ok(( - parse_physical_expr(expr, registry, input.schema().as_ref())?, + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + )?, name.to_string(), )) }) @@ -156,7 +162,12 @@ impl AsExecutionPlan for PhysicalPlanNode { .expr .as_ref() .map(|expr| { - parse_physical_expr(expr, registry, input.schema().as_ref()) + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + ) }) .transpose()? .ok_or_else(|| { @@ -208,6 +219,7 @@ impl AsExecutionPlan for PhysicalPlanNode { expr, registry, base_config.file_schema.as_ref(), + extension_codec, ) }) .transpose()?; @@ -254,7 +266,12 @@ impl AsExecutionPlan for PhysicalPlanNode { .hash_expr .iter() .map(|e| { - parse_physical_expr(e, registry, input.schema().as_ref()) + parse_physical_expr( + e, + registry, + input.schema().as_ref(), + extension_codec, + ) }) .collect::>, _>>()?; @@ -329,7 +346,12 @@ impl AsExecutionPlan for PhysicalPlanNode { .partition_keys .iter() .map(|expr| { - parse_physical_expr(expr, registry, input.schema().as_ref()) + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + ) }) .collect::>>>()?; @@ -396,8 +418,13 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { - parse_physical_expr(expr, registry, input.schema().as_ref()) - .map(|expr| (expr, name.to_string())) + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + ) + .map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; @@ -406,8 +433,13 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { - parse_physical_expr(expr, registry, input.schema().as_ref()) - .map(|expr| (expr, name.to_string())) + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + ) + .map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; @@ -434,7 +466,14 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|expr| { expr.expr .as_ref() - .map(|e| parse_physical_expr(e, registry, &physical_schema)) + .map(|e| { + parse_physical_expr( + e, + registry, + &physical_schema, + extension_codec, + ) + }) .transpose() }) .collect::, _>>()?; @@ -451,7 +490,7 @@ impl AsExecutionPlan for PhysicalPlanNode { match expr_type { ExprType::AggregateExpr(agg_node) => { let input_phy_expr: Vec> = agg_node.expr.iter() - .map(|e| parse_physical_expr(e, registry, &physical_schema).unwrap()).collect(); + .map(|e| parse_physical_expr(e, registry, &physical_schema, extension_codec).unwrap()).collect(); let ordering_req: Vec = agg_node.ordering_req.iter() .map(|e| parse_physical_sort_expr(e, registry, &physical_schema).unwrap()).collect(); agg_node.aggregate_function.as_ref().map(|func| { @@ -524,11 +563,13 @@ impl AsExecutionPlan for PhysicalPlanNode { &col.left.clone().unwrap(), registry, left_schema.as_ref(), + extension_codec, )?; let right = parse_physical_expr( &col.right.clone().unwrap(), registry, right_schema.as_ref(), + extension_codec, )?; Ok((left, right)) }) @@ -555,6 +596,7 @@ impl AsExecutionPlan for PhysicalPlanNode { proto_error("Unexpected empty filter expression") })?, registry, &schema, + extension_codec, )?; let column_indices = f.column_indices .iter() @@ -635,11 +677,13 @@ impl AsExecutionPlan for PhysicalPlanNode { &col.left.clone().unwrap(), registry, left_schema.as_ref(), + extension_codec, )?; let right = parse_physical_expr( &col.right.clone().unwrap(), registry, right_schema.as_ref(), + extension_codec, )?; Ok((left, right)) }) @@ -666,6 +710,7 @@ impl AsExecutionPlan for PhysicalPlanNode { proto_error("Unexpected empty filter expression") })?, registry, &schema, + extension_codec, )?; let column_indices = f.column_indices .iter() @@ -805,7 +850,7 @@ impl AsExecutionPlan for PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref(), extension_codec)?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -852,7 +897,7 @@ impl AsExecutionPlan for PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref(), extension_codec)?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -916,6 +961,7 @@ impl AsExecutionPlan for PhysicalPlanNode { proto_error("Unexpected empty filter expression") })?, registry, &schema, + extension_codec, )?; let column_indices = f.column_indices .iter() @@ -1088,7 +1134,7 @@ impl AsExecutionPlan for PhysicalPlanNode { let expr = exec .expr() .iter() - .map(|expr| expr.0.clone().try_into()) + .map(|expr| serialize_physical_expr(expr.0.clone(), extension_codec)) .collect::>>()?; let expr_name = exec.expr().iter().map(|expr| expr.1.clone()).collect(); return Ok(protobuf::PhysicalPlanNode { @@ -1128,7 +1174,10 @@ impl AsExecutionPlan for PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Filter(Box::new( protobuf::FilterExecNode { input: Some(Box::new(input)), - expr: Some(exec.predicate().clone().try_into()?), + expr: Some(serialize_physical_expr( + exec.predicate().clone(), + extension_codec, + )?), default_filter_selectivity: exec.default_selectivity() as u32, }, ))), @@ -1183,8 +1232,8 @@ impl AsExecutionPlan for PhysicalPlanNode { .on() .iter() .map(|tuple| { - let l = tuple.0.to_owned().try_into()?; - let r = tuple.1.to_owned().try_into()?; + let l = serialize_physical_expr(tuple.0.to_owned(), extension_codec)?; + let r = serialize_physical_expr(tuple.1.to_owned(), extension_codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), @@ -1196,7 +1245,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .filter() .as_ref() .map(|f| { - let expression = f.expression().to_owned().try_into()?; + let expression = serialize_physical_expr( + f.expression().to_owned(), + extension_codec, + )?; let column_indices = f .column_indices() .iter() @@ -1254,8 +1306,8 @@ impl AsExecutionPlan for PhysicalPlanNode { .on() .iter() .map(|tuple| { - let l = tuple.0.to_owned().try_into()?; - let r = tuple.1.to_owned().try_into()?; + let l = serialize_physical_expr(tuple.0.to_owned(), extension_codec)?; + let r = serialize_physical_expr(tuple.1.to_owned(), extension_codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), @@ -1267,7 +1319,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .filter() .as_ref() .map(|f| { - let expression = f.expression().to_owned().try_into()?; + let expression = serialize_physical_expr( + f.expression().to_owned(), + extension_codec, + )?; let column_indices = f .column_indices() .iter() @@ -1304,7 +1359,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { Ok(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_physical_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }) @@ -1321,7 +1379,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { Ok(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_physical_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }) @@ -1423,14 +1484,14 @@ impl AsExecutionPlan for PhysicalPlanNode { .group_expr() .null_expr() .iter() - .map(|expr| expr.0.to_owned().try_into()) + .map(|expr| serialize_physical_expr(expr.0.to_owned(), extension_codec)) .collect::>>()?; let group_expr = exec .group_expr() .expr() .iter() - .map(|expr| expr.0.to_owned().try_into()) + .map(|expr| serialize_physical_expr(expr.0.to_owned(), extension_codec)) .collect::>>()?; return Ok(protobuf::PhysicalPlanNode { @@ -1512,7 +1573,7 @@ impl AsExecutionPlan for PhysicalPlanNode { if let Some(exec) = plan.downcast_ref::() { let predicate = exec .predicate() - .map(|pred| pred.clone().try_into()) + .map(|pred| serialize_physical_expr(pred.clone(), extension_codec)) .transpose()?; return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::ParquetScan( @@ -1559,7 +1620,9 @@ impl AsExecutionPlan for PhysicalPlanNode { PartitionMethod::Hash(protobuf::PhysicalHashRepartition { hash_expr: exprs .iter() - .map(|expr| expr.clone().try_into()) + .map(|expr| { + serialize_physical_expr(expr.clone(), extension_codec) + }) .collect::>>()?, partition_count: *partition_count as u64, }) @@ -1592,7 +1655,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { let sort_expr = Box::new(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_physical_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }); @@ -1658,7 +1724,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { let sort_expr = Box::new(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_physical_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }); @@ -1695,7 +1764,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .filter() .as_ref() .map(|f| { - let expression = f.expression().to_owned().try_into()?; + let expression = serialize_physical_expr( + f.expression().to_owned(), + extension_codec, + )?; let column_indices = f .column_indices() .iter() @@ -1743,7 +1815,7 @@ impl AsExecutionPlan for PhysicalPlanNode { let partition_keys = exec .partition_keys .iter() - .map(|e| e.clone().try_into()) + .map(|e| serialize_physical_expr(e.clone(), extension_codec)) .collect::>>()?; return Ok(protobuf::PhysicalPlanNode { @@ -1773,7 +1845,7 @@ impl AsExecutionPlan for PhysicalPlanNode { let partition_keys = exec .partition_keys .iter() - .map(|e| e.clone().try_into()) + .map(|e| serialize_physical_expr(e.clone(), extension_codec)) .collect::>>()?; let input_order_mode = match &exec.input_order_mode { @@ -1816,7 +1888,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|requirement| { let expr: PhysicalSortExpr = requirement.to_owned().into(); let sort_expr = protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_physical_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }; diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index ba77b30b7f8d..b66709d0c5bd 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -22,7 +22,6 @@ use std::{ sync::Arc, }; -use crate::logical_plan::csv_writer_options_to_proto; use crate::protobuf::{ self, copy_to_node, physical_aggregate_expr_node, physical_window_expr_node, scalar_value::Value, ArrowOptions, AvroOptions, PhysicalSortExprNode, @@ -31,13 +30,10 @@ use crate::protobuf::{ #[cfg(feature = "parquet")] use datafusion::datasource::file_format::parquet::ParquetSink; -use datafusion::datasource::{ - file_format::csv::CsvSink, - file_format::json::JsonSink, - listing::{FileRange, PartitionedFile}, - physical_plan::FileScanConfig, - physical_plan::FileSinkConfig, -}; + +use datafusion_expr::ScalarFunctionDefinition; + +use crate::logical_plan::csv_writer_options_to_proto; use datafusion::logical_expr::BuiltinScalarFunction; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; @@ -46,16 +42,24 @@ use datafusion::physical_plan::expressions::{ ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, Correlation, Count, Covariance, CovariancePop, CumeDist, DistinctArrayAgg, DistinctBitXor, DistinctCount, DistinctSum, FirstValue, Grouping, - InListExpr, IsNotNullExpr, IsNullExpr, LastValue, LikeExpr, Literal, Max, Median, - Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, - Rank, RankType, Regr, RegrType, RowNumber, Stddev, StddevPop, StringAgg, Sum, - TryCastExpr, Variance, VariancePop, WindowShift, + InListExpr, IsNotNullExpr, IsNullExpr, LastValue, Literal, Max, Median, Min, + NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, + RankType, Regr, RegrType, RowNumber, Stddev, StddevPop, StringAgg, Sum, TryCastExpr, + Variance, VariancePop, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; use datafusion::physical_plan::{ AggregateExpr, ColumnStatistics, PhysicalExpr, Statistics, WindowExpr, }; +use datafusion::{ + datasource::{ + file_format::{csv::CsvSink, json::JsonSink}, + listing::{FileRange, PartitionedFile}, + physical_plan::{FileScanConfig, FileSinkConfig}, + }, + physical_plan::expressions::LikeExpr, +}; use datafusion_common::config::{ ColumnOptions, CsvOptions, FormatOptions, JsonOptions, ParquetOptions, TableParquetOptions, @@ -68,14 +72,17 @@ use datafusion_common::{ DataFusionError, JoinSide, Result, }; +use super::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec}; + impl TryFrom> for protobuf::PhysicalExprNode { type Error = DataFusionError; fn try_from(a: Arc) -> Result { + let codec = DefaultPhysicalExtensionCodec {}; let expressions: Vec = a .expressions() .iter() - .map(|e| e.clone().try_into()) + .map(|e| serialize_physical_expr(e.clone(), &codec)) .collect::>>()?; let ordering_req: Vec = a @@ -237,16 +244,16 @@ impl TryFrom> for protobuf::PhysicalWindowExprNode { } else { return not_impl_err!("WindowExpr not supported: {window_expr:?}"); }; - + let codec = DefaultPhysicalExtensionCodec {}; let args = args .into_iter() - .map(|e| e.try_into()) + .map(|e| serialize_physical_expr(e, &codec)) .collect::>>()?; let partition_by = window_expr .partition_by() .iter() - .map(|p| p.clone().try_into()) + .map(|p| serialize_physical_expr(p.clone(), &codec)) .collect::>>()?; let order_by = window_expr @@ -374,195 +381,250 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { Ok(AggrFn { inner, distinct }) } -impl TryFrom> for protobuf::PhysicalExprNode { - type Error = DataFusionError; - - fn try_from(value: Arc) -> Result { - let expr = value.as_any(); - - if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Column( - protobuf::PhysicalColumn { - name: expr.name().to_string(), - index: expr.index() as u32, - }, - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - let binary_expr = Box::new(protobuf::PhysicalBinaryExprNode { - l: Some(Box::new(expr.left().to_owned().try_into()?)), - r: Some(Box::new(expr.right().to_owned().try_into()?)), - op: format!("{:?}", expr.op()), - }); +/// Serialize a `PhysicalExpr` to default protobuf representation. +/// +/// If required, a [`PhysicalExtensionCodec`] can be provided which can handle +/// serialization of udfs requiring specialized serialization (see [`PhysicalExtensionCodec::try_encode_udf`]) +pub fn serialize_physical_expr( + value: Arc, + codec: &dyn PhysicalExtensionCodec, +) -> Result { + let expr = value.as_any(); + + if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Column( + protobuf::PhysicalColumn { + name: expr.name().to_string(), + index: expr.index() as u32, + }, + )), + }) + } else if let Some(expr) = expr.downcast_ref::() { + let binary_expr = Box::new(protobuf::PhysicalBinaryExprNode { + l: Some(Box::new(serialize_physical_expr( + expr.left().clone(), + codec, + )?)), + r: Some(Box::new(serialize_physical_expr( + expr.right().clone(), + codec, + )?)), + op: format!("{:?}", expr.op()), + }); - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::BinaryExpr( - binary_expr, - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some( - protobuf::physical_expr_node::ExprType::Case( - Box::new( - protobuf::PhysicalCaseNode { - expr: expr - .expr() - .map(|exp| exp.clone().try_into().map(Box::new)) - .transpose()?, - when_then_expr: expr - .when_then_expr() - .iter() - .map(|(when_expr, then_expr)| { - try_parse_when_then_expr(when_expr, then_expr) - }) - .collect::, - Self::Error, - >>()?, - else_expr: expr - .else_expr() - .map(|a| a.clone().try_into().map(Box::new)) - .transpose()?, - }, - ), + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::BinaryExpr( + binary_expr, + )), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some( + protobuf::physical_expr_node::ExprType::Case( + Box::new( + protobuf::PhysicalCaseNode { + expr: expr + .expr() + .map(|exp| { + serialize_physical_expr(exp.clone(), codec) + .map(Box::new) + }) + .transpose()?, + when_then_expr: expr + .when_then_expr() + .iter() + .map(|(when_expr, then_expr)| { + try_parse_when_then_expr(when_expr, then_expr, codec) + }) + .collect::, + DataFusionError, + >>()?, + else_expr: expr + .else_expr() + .map(|a| { + serialize_physical_expr(a.clone(), codec) + .map(Box::new) + }) + .transpose()?, + }, ), ), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::NotExpr( - Box::new(protobuf::PhysicalNot { - expr: Some(Box::new(expr.arg().to_owned().try_into()?)), - }), - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( - Box::new(protobuf::PhysicalIsNull { - expr: Some(Box::new(expr.arg().to_owned().try_into()?)), - }), - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::IsNotNullExpr( - Box::new(protobuf::PhysicalIsNotNull { - expr: Some(Box::new(expr.arg().to_owned().try_into()?)), - }), - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some( - protobuf::physical_expr_node::ExprType::InList( - Box::new( - protobuf::PhysicalInListNode { - expr: Some(Box::new(expr.expr().to_owned().try_into()?)), - list: expr - .list() - .iter() - .map(|a| a.clone().try_into()) - .collect::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::NotExpr(Box::new( + protobuf::PhysicalNot { + expr: Some(Box::new(serialize_physical_expr( + expr.arg().to_owned(), + codec, + )?)), + }, + ))), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( + Box::new(protobuf::PhysicalIsNull { + expr: Some(Box::new(serialize_physical_expr( + expr.arg().to_owned(), + codec, + )?)), + }), + )), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::IsNotNullExpr( + Box::new(protobuf::PhysicalIsNotNull { + expr: Some(Box::new(serialize_physical_expr( + expr.arg().to_owned(), + codec, + )?)), + }), + )), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some( + protobuf::physical_expr_node::ExprType::InList( + Box::new( + protobuf::PhysicalInListNode { + expr: Some(Box::new(serialize_physical_expr( + expr.expr().to_owned(), + codec, + )?)), + list: expr + .list() + .iter() + .map(|a| serialize_physical_expr(a.clone(), codec)) + .collect::, - Self::Error, + DataFusionError, >>()?, - negated: expr.negated(), - }, - ), + negated: expr.negated(), + }, ), ), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Negative( - Box::new(protobuf::PhysicalNegativeNode { - expr: Some(Box::new(expr.arg().to_owned().try_into()?)), - }), - )), - }) - } else if let Some(lit) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Literal( - lit.value().try_into()?, - )), - }) - } else if let Some(cast) = expr.downcast_ref::() { + ), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Negative(Box::new( + protobuf::PhysicalNegativeNode { + expr: Some(Box::new(serialize_physical_expr( + expr.arg().to_owned(), + codec, + )?)), + }, + ))), + }) + } else if let Some(lit) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Literal( + lit.value().try_into()?, + )), + }) + } else if let Some(cast) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Cast(Box::new( + protobuf::PhysicalCastNode { + expr: Some(Box::new(serialize_physical_expr( + cast.expr().to_owned(), + codec, + )?)), + arrow_type: Some(cast.cast_type().try_into()?), + }, + ))), + }) + } else if let Some(cast) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast(Box::new( + protobuf::PhysicalTryCastNode { + expr: Some(Box::new(serialize_physical_expr( + cast.expr().to_owned(), + codec, + )?)), + arrow_type: Some(cast.cast_type().try_into()?), + }, + ))), + }) + } else if let Some(expr) = expr.downcast_ref::() { + let args: Vec = expr + .args() + .iter() + .map(|e| serialize_physical_expr(e.to_owned(), codec)) + .collect::, _>>()?; + if let Ok(fun) = BuiltinScalarFunction::from_str(expr.name()) { + let fun: protobuf::ScalarFunction = (&fun).try_into()?; + Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Cast(Box::new( - protobuf::PhysicalCastNode { - expr: Some(Box::new(cast.expr().clone().try_into()?)), - arrow_type: Some(cast.cast_type().try_into()?), + expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarFunction( + protobuf::PhysicalScalarFunctionNode { + name: expr.name().to_string(), + fun: fun.into(), + args, + return_type: Some(expr.return_type().try_into()?), }, - ))), - }) - } else if let Some(cast) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast( - Box::new(protobuf::PhysicalTryCastNode { - expr: Some(Box::new(cast.expr().clone().try_into()?)), - arrow_type: Some(cast.cast_type().try_into()?), - }), )), }) - } else if let Some(expr) = expr.downcast_ref::() { - let args: Vec = expr - .args() - .iter() - .map(|e| e.to_owned().try_into()) - .collect::, _>>()?; - if let Ok(fun) = BuiltinScalarFunction::from_str(expr.name()) { - let fun: protobuf::ScalarFunction = (&fun).try_into()?; - - Ok(protobuf::PhysicalExprNode { - expr_type: Some( - protobuf::physical_expr_node::ExprType::ScalarFunction( - protobuf::PhysicalScalarFunctionNode { - name: expr.name().to_string(), - fun: fun.into(), - args, - return_type: Some(expr.return_type().try_into()?), - }, - ), - ), - }) - } else { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf( - protobuf::PhysicalScalarUdfNode { - name: expr.name().to_string(), - args, - return_type: Some(expr.return_type().try_into()?), - }, - )), - }) + } else { + let mut buf = Vec::new(); + match expr.fun() { + ScalarFunctionDefinition::UDF(udf) => { + codec.try_encode_udf(udf, &mut buf)?; + } + _ => { + return not_impl_err!( + "Proto serialization error: Trying to serialize a unresolved function" + ); + } } - } else if let Some(expr) = expr.downcast_ref::() { + + let fun_definition = if buf.is_empty() { None } else { Some(buf) }; Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::LikeExpr( - Box::new(protobuf::PhysicalLikeExprNode { - negated: expr.negated(), - case_insensitive: expr.case_insensitive(), - expr: Some(Box::new(expr.expr().to_owned().try_into()?)), - pattern: Some(Box::new(expr.pattern().to_owned().try_into()?)), - }), + expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf( + protobuf::PhysicalScalarUdfNode { + name: expr.name().to_string(), + args, + fun_definition, + return_type: Some(expr.return_type().try_into()?), + }, )), }) - } else { - internal_err!("physical_plan::to_proto() unsupported expression {value:?}") } + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::LikeExpr(Box::new( + protobuf::PhysicalLikeExprNode { + negated: expr.negated(), + case_insensitive: expr.case_insensitive(), + expr: Some(Box::new(serialize_physical_expr( + expr.expr().to_owned(), + codec, + )?)), + pattern: Some(Box::new(serialize_physical_expr( + expr.pattern().to_owned(), + codec, + )?)), + }, + ))), + }) + } else { + internal_err!("physical_plan::to_proto() unsupported expression {value:?}") } } fn try_parse_when_then_expr( when_expr: &Arc, then_expr: &Arc, + codec: &dyn PhysicalExtensionCodec, ) -> Result { Ok(protobuf::PhysicalWhenThen { - when_expr: Some(when_expr.clone().try_into()?), - then_expr: Some(then_expr.clone().try_into()?), + when_expr: Some(serialize_physical_expr(when_expr.clone(), codec)?), + then_expr: Some(serialize_physical_expr(then_expr.clone(), codec)?), }) } @@ -683,6 +745,7 @@ impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf { fn try_from( conf: &FileScanConfig, ) -> Result { + let codec = DefaultPhysicalExtensionCodec {}; let file_groups = conf .file_groups .iter() @@ -694,7 +757,7 @@ impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf { let expr_node_vec = order .iter() .map(|sort_expr| { - let expr = sort_expr.expr.clone().try_into()?; + let expr = serialize_physical_expr(sort_expr.expr.clone(), &codec)?; Ok(PhysicalSortExprNode { expr: Some(Box::new(expr)), asc: !sort_expr.options.descending, @@ -757,10 +820,11 @@ impl TryFrom>> for protobuf::MaybeFilter { type Error = DataFusionError; fn try_from(expr: Option>) -> Result { + let codec = DefaultPhysicalExtensionCodec {}; match expr { None => Ok(protobuf::MaybeFilter { expr: None }), Some(expr) => Ok(protobuf::MaybeFilter { - expr: Some(expr.try_into()?), + expr: Some(serialize_physical_expr(expr, &codec)?), }), } } @@ -786,8 +850,9 @@ impl TryFrom for protobuf::PhysicalSortExprNode { type Error = DataFusionError; fn try_from(sort_expr: PhysicalSortExpr) -> std::result::Result { + let codec = DefaultPhysicalExtensionCodec {}; Ok(PhysicalSortExprNode { - expr: Some(Box::new(sort_expr.expr.try_into()?)), + expr: Some(Box::new(serialize_physical_expr(sort_expr.expr, &codec)?)), asc: !sort_expr.options.descending, nulls_first: sort_expr.options.nulls_first, }) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 7f0c6286a19d..4924128ae190 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; use std::ops::Deref; use std::sync::Arc; use std::vec; @@ -32,7 +33,7 @@ use datafusion::datasource::physical_plan::{ wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig, FileSinkConfig, ParquetExec, }; -use datafusion::execution::context::ExecutionProps; +use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{ create_udf, BuiltinScalarFunction, JoinType, Operator, Volatility, }; @@ -49,7 +50,6 @@ use datafusion::physical_plan::expressions::{ NotExpr, NthValue, PhysicalSortExpr, StringAgg, Sum, }; use datafusion::physical_plan::filter::FilterExec; -use datafusion::physical_plan::functions; use datafusion::physical_plan::insert::FileSinkExec; use datafusion::physical_plan::joins::{ HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, @@ -73,13 +73,19 @@ use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; -use datafusion_common::Result; +use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use datafusion_expr::{ - Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, Signature, - SimpleAggregateUDF, WindowFrame, WindowFrameBound, + Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, + ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, + WindowFrame, WindowFrameBound, +}; +use datafusion_proto::physical_plan::from_proto::parse_physical_expr; +use datafusion_proto::physical_plan::to_proto::serialize_physical_expr; +use datafusion_proto::physical_plan::{ + AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; -use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec}; use datafusion_proto::protobuf; +use prost::Message; /// Perform a serde roundtrip and assert that the string representation of the before and after plans /// are identical. Note that this often isn't sufficient to guarantee that no information is @@ -603,14 +609,11 @@ fn roundtrip_builtin_scalar_function() -> Result<()> { let input = Arc::new(EmptyExec::new(schema.clone())); - let execution_props = ExecutionProps::new(); - - let fun_expr = - functions::create_physical_fun(&BuiltinScalarFunction::Sin, &execution_props)?; + let fun_def = ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Sin); let expr = ScalarFunctionExpr::new( "sin", - fun_expr, + fun_def, vec![col("a", &schema)?], DataType::Float64, None, @@ -646,9 +649,11 @@ fn roundtrip_scalar_udf() -> Result<()> { scalar_fn.clone(), ); + let fun_def = ScalarFunctionDefinition::UDF(Arc::new(udf.clone())); + let expr = ScalarFunctionExpr::new( "dummy", - scalar_fn, + fun_def, vec![col("a", &schema)?], DataType::Int64, None, @@ -665,6 +670,134 @@ fn roundtrip_scalar_udf() -> Result<()> { roundtrip_test_with_context(Arc::new(project), ctx) } +#[test] +fn roundtrip_scalar_udf_extension_codec() { + #[derive(Debug)] + struct MyRegexUdf { + signature: Signature, + // regex as original string + pattern: String, + } + + impl MyRegexUdf { + fn new(pattern: String) -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Int32], + Volatility::Immutable, + ), + pattern, + } + } + } + + /// Implement the ScalarUDFImpl trait for MyRegexUdf + impl ScalarUDFImpl for MyRegexUdf { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "regex_udf" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, args: &[DataType]) -> Result { + if !matches!(args.first(), Some(&DataType::Utf8)) { + return plan_err!("regex_udf only accepts Utf8 arguments"); + } + Ok(DataType::Int32) + } + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!() + } + } + + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct MyRegexUdfNode { + #[prost(string, tag = "1")] + pub pattern: String, + } + + #[derive(Debug)] + pub struct ScalarUDFExtensionCodec {} + + impl PhysicalExtensionCodec for ScalarUDFExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[Arc], + _registry: &dyn FunctionRegistry, + ) -> Result> { + not_impl_err!("No extension codec provided") + } + + fn try_encode( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + not_impl_err!("No extension codec provided") + } + + fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + if name == "regex_udf" { + let proto = MyRegexUdfNode::decode(buf).map_err(|err| { + DataFusionError::Internal(format!( + "failed to decode regex_udf: {}", + err + )) + })?; + + Ok(Arc::new(ScalarUDF::new_from_impl(MyRegexUdf::new( + proto.pattern, + )))) + } else { + not_impl_err!("unrecognized scalar UDF implementation, cannot decode") + } + } + + fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + let binding = node.inner(); + if let Some(udf) = binding.as_any().downcast_ref::() { + let proto = MyRegexUdfNode { + pattern: udf.pattern.clone(), + }; + proto.encode(buf).map_err(|e| { + DataFusionError::Internal(format!("failed to encode udf: {e:?}")) + })?; + } + Ok(()) + } + } + + let pattern = ".*"; + let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string())); + let test_expr = ScalarFunctionExpr::new( + udf.name(), + ScalarFunctionDefinition::UDF(Arc::new(udf.clone())), + vec![], + DataType::Int32, + None, + false, + ); + let fmt_expr = format!("{test_expr:?}"); + let ctx = SessionContext::new(); + + ctx.register_udf(udf.clone()); + let extension_codec = ScalarUDFExtensionCodec {}; + let proto: protobuf::PhysicalExprNode = + match serialize_physical_expr(Arc::new(test_expr), &extension_codec) { + Ok(proto) => proto, + Err(e) => panic!("failed to serialize expr: {e:?}"), + }; + let field_a = Field::new("a", DataType::Int32, false); + let schema = Arc::new(Schema::new(vec![field_a])); + let round_trip = + parse_physical_expr(&proto, &ctx, &schema, &extension_codec).unwrap(); + assert_eq!(fmt_expr, format!("{round_trip:?}")); +} #[test] fn roundtrip_distinct_count() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); From 8074ca1e758470319699a562074290906003b312 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Tue, 19 Mar 2024 12:14:13 -0600 Subject: [PATCH 19/35] Support Union types in `ScalarValue` (#9683) Support Union types in `ScalarValue` (#9683) --- datafusion/common/src/error.rs | 4 +- datafusion/common/src/scalar/mod.rs | 82 ++++++ datafusion/physical-plan/src/filter.rs | 35 +++ datafusion/proto/proto/datafusion.proto | 15 + datafusion/proto/src/generated/pbjson.rs | 272 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 26 +- .../proto/src/logical_plan/from_proto.rs | 35 +++ datafusion/proto/src/logical_plan/to_proto.rs | 29 ++ datafusion/sql/src/unparser/expr.rs | 1 + 9 files changed, 496 insertions(+), 3 deletions(-) diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 1ecd5b62bee8..d1e47b473499 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -63,7 +63,7 @@ pub enum DataFusionError { IoError(io::Error), /// Error when SQL is syntactically incorrect. /// - /// 2nd argument is for optional backtrace + /// 2nd argument is for optional backtrace SQL(ParserError, Option), /// Error when a feature is not yet implemented. /// @@ -101,7 +101,7 @@ pub enum DataFusionError { /// This error can be returned in cases such as when schema inference is not /// possible and when column names are not unique. /// - /// 2nd argument is for optional backtrace + /// 2nd argument is for optional backtrace /// Boxing the optional backtrace to prevent SchemaError(SchemaError, Box>), /// Error during execution of the query. diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index a2484e93e812..d33b8b6e142c 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -53,6 +53,8 @@ use arrow::{ }, }; use arrow_array::{ArrowNativeTypeOp, Scalar}; +use arrow_buffer::Buffer; +use arrow_schema::{UnionFields, UnionMode}; pub use struct_builder::ScalarStructBuilder; @@ -275,6 +277,11 @@ pub enum ScalarValue { DurationMicrosecond(Option), /// Duration in nanoseconds DurationNanosecond(Option), + /// A nested datatype that can represent slots of differing types. Components: + /// `.0`: a tuple of union `type_id` and the single value held by this Scalar + /// `.1`: the list of fields, zero-to-one of which will by set in `.0` + /// `.2`: the physical storage of the source/destination UnionArray from which this Scalar came + Union(Option<(i8, Box)>, UnionFields, UnionMode), /// Dictionary type: index type and value Dictionary(Box, Box), } @@ -375,6 +382,10 @@ impl PartialEq for ScalarValue { (IntervalDayTime(_), _) => false, (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.eq(v2), (IntervalMonthDayNano(_), _) => false, + (Union(val1, fields1, mode1), Union(val2, fields2, mode2)) => { + val1.eq(val2) && fields1.eq(fields2) && mode1.eq(mode2) + } + (Union(_, _, _), _) => false, (Dictionary(k1, v1), Dictionary(k2, v2)) => k1.eq(k2) && v1.eq(v2), (Dictionary(_, _), _) => false, (Null, Null) => true, @@ -500,6 +511,14 @@ impl PartialOrd for ScalarValue { (DurationMicrosecond(_), _) => None, (DurationNanosecond(v1), DurationNanosecond(v2)) => v1.partial_cmp(v2), (DurationNanosecond(_), _) => None, + (Union(v1, t1, m1), Union(v2, t2, m2)) => { + if t1.eq(t2) && m1.eq(m2) { + v1.partial_cmp(v2) + } else { + None + } + } + (Union(_, _, _), _) => None, (Dictionary(k1, v1), Dictionary(k2, v2)) => { // Don't compare if the key types don't match (it is effectively a different datatype) if k1 == k2 { @@ -663,6 +682,11 @@ impl std::hash::Hash for ScalarValue { IntervalYearMonth(v) => v.hash(state), IntervalDayTime(v) => v.hash(state), IntervalMonthDayNano(v) => v.hash(state), + Union(v, t, m) => { + v.hash(state); + t.hash(state); + m.hash(state); + } Dictionary(k, v) => { k.hash(state); v.hash(state); @@ -1093,6 +1117,7 @@ impl ScalarValue { ScalarValue::DurationNanosecond(_) => { DataType::Duration(TimeUnit::Nanosecond) } + ScalarValue::Union(_, fields, mode) => DataType::Union(fields.clone(), *mode), ScalarValue::Dictionary(k, v) => { DataType::Dictionary(k.clone(), Box::new(v.data_type())) } @@ -1292,6 +1317,7 @@ impl ScalarValue { ScalarValue::DurationMillisecond(v) => v.is_none(), ScalarValue::DurationMicrosecond(v) => v.is_none(), ScalarValue::DurationNanosecond(v) => v.is_none(), + ScalarValue::Union(v, _, _) => v.is_none(), ScalarValue::Dictionary(_, v) => v.is_null(), } } @@ -2087,6 +2113,39 @@ impl ScalarValue { e, size ), + ScalarValue::Union(value, fields, _mode) => match value { + Some((v_id, value)) => { + let mut field_type_ids = Vec::::with_capacity(fields.len()); + let mut child_arrays = + Vec::<(Field, ArrayRef)>::with_capacity(fields.len()); + for (f_id, field) in fields.iter() { + let ar = if f_id == *v_id { + value.to_array_of_size(size)? + } else { + let dt = field.data_type(); + new_null_array(dt, size) + }; + let field = (**field).clone(); + child_arrays.push((field, ar)); + field_type_ids.push(f_id); + } + let type_ids = repeat(*v_id).take(size).collect::>(); + let type_ids = Buffer::from_slice_ref(type_ids); + let value_offsets: Option = None; + let ar = UnionArray::try_new( + field_type_ids.as_slice(), + type_ids, + value_offsets, + child_arrays, + ) + .map_err(|e| DataFusionError::ArrowError(e, None))?; + Arc::new(ar) + } + None => { + let dt = self.data_type(); + new_null_array(&dt, size) + } + }, ScalarValue::Dictionary(key_type, v) => { // values array is one element long (the value) match key_type.as_ref() { @@ -2626,6 +2685,9 @@ impl ScalarValue { ScalarValue::DurationNanosecond(val) => { eq_array_primitive!(array, index, DurationNanosecondArray, val)? } + ScalarValue::Union(_, _, _) => { + return _not_impl_err!("Union is not supported yet") + } ScalarValue::Dictionary(key_type, v) => { let (values_array, values_index) = match key_type.as_ref() { DataType::Int8 => get_dict_value::(array, index)?, @@ -2703,6 +2765,15 @@ impl ScalarValue { ScalarValue::LargeList(arr) => arr.get_array_memory_size(), ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(), ScalarValue::Struct(arr) => arr.get_array_memory_size(), + ScalarValue::Union(vals, fields, _mode) => { + vals.as_ref() + .map(|(_id, sv)| sv.size() - std::mem::size_of_val(sv)) + .unwrap_or_default() + // `fields` is boxed, so it is NOT already included in `self` + + std::mem::size_of_val(fields) + + (std::mem::size_of::() * fields.len()) + + fields.iter().map(|(_idx, field)| field.size() - std::mem::size_of_val(field)).sum::() + } ScalarValue::Dictionary(dt, sv) => { // `dt` and `sv` are boxed, so they are NOT already included in `self` dt.size() + sv.size() @@ -3048,6 +3119,9 @@ impl TryFrom<&DataType> for ScalarValue { .to_owned() .into(), ), + DataType::Union(fields, mode) => { + ScalarValue::Union(None, fields.clone(), *mode) + } DataType::Null => ScalarValue::Null, _ => { return _not_impl_err!( @@ -3164,6 +3238,10 @@ impl fmt::Display for ScalarValue { .join(",") )? } + ScalarValue::Union(val, _fields, _mode) => match val { + Some((id, val)) => write!(f, "{}:{}", id, val)?, + None => write!(f, "NULL")?, + }, ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?, ScalarValue::Null => write!(f, "NULL")?, }; @@ -3279,6 +3357,10 @@ impl fmt::Debug for ScalarValue { ScalarValue::DurationNanosecond(_) => { write!(f, "DurationNanosecond(\"{self}\")") } + ScalarValue::Union(val, _fields, _mode) => match val { + Some((id, val)) => write!(f, "Union {}:{}", id, val), + None => write!(f, "Union(NULL)"), + }, ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({k:?}, {v:?})"), ScalarValue::Null => write!(f, "NULL"), } diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 72f885a93962..f44ade7106df 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -441,7 +441,9 @@ mod tests { use crate::test::exec::StatisticsExec; use crate::ExecutionPlan; + use crate::empty::EmptyExec; use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::{UnionFields, UnionMode}; use datafusion_common::{ColumnStatistics, ScalarValue}; use datafusion_expr::Operator; @@ -1090,4 +1092,37 @@ mod tests { assert_eq!(statistics.total_byte_size, Precision::Inexact(1600)); Ok(()) } + + #[test] + fn test_equivalence_properties_union_type() -> Result<()> { + let union_type = DataType::Union( + UnionFields::new( + vec![0, 1], + vec![ + Field::new("f1", DataType::Int32, true), + Field::new("f2", DataType::Utf8, true), + ], + ), + UnionMode::Sparse, + ); + + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", union_type, true), + ])); + + let exec = FilterExec::try_new( + binary( + binary(col("c1", &schema)?, Operator::GtEq, lit(1i32), &schema)?, + Operator::And, + binary(col("c1", &schema)?, Operator::LtEq, lit(4i32), &schema)?, + &schema, + )?, + Arc::new(EmptyExec::new(schema.clone())), + )?; + + exec.statistics().unwrap(); + + Ok(()) + } } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 7a9b427ce7d3..10f79a2b8cc8 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -988,6 +988,20 @@ message IntervalMonthDayNanoValue { int64 nanos = 3; } +message UnionField { + int32 field_id = 1; + Field field = 2; +} + +message UnionValue { + // Note that a null union value must have one or more fields, so we + // encode a null UnionValue as one with value_id == 128 + int32 value_id = 1; + ScalarValue value = 2; + repeated UnionField fields = 3; + UnionMode mode = 4; +} + message ScalarFixedSizeBinary{ bytes values = 1; int32 length = 2; @@ -1042,6 +1056,7 @@ message ScalarValue{ ScalarTime64Value time64_value = 30; IntervalMonthDayNanoValue interval_month_day_nano = 31; ScalarFixedSizeBinary fixed_size_binary_value = 34; + UnionValue union_value = 42; } } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index fd27520b3be0..7757a64ef359 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -24053,6 +24053,9 @@ impl serde::Serialize for ScalarValue { scalar_value::Value::FixedSizeBinaryValue(v) => { struct_ser.serialize_field("fixedSizeBinaryValue", v)?; } + scalar_value::Value::UnionValue(v) => { + struct_ser.serialize_field("unionValue", v)?; + } } } struct_ser.end() @@ -24137,6 +24140,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "intervalMonthDayNano", "fixed_size_binary_value", "fixedSizeBinaryValue", + "union_value", + "unionValue", ]; #[allow(clippy::enum_variant_names)] @@ -24177,6 +24182,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { Time64Value, IntervalMonthDayNano, FixedSizeBinaryValue, + UnionValue, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -24234,6 +24240,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "time64Value" | "time64_value" => Ok(GeneratedField::Time64Value), "intervalMonthDayNano" | "interval_month_day_nano" => Ok(GeneratedField::IntervalMonthDayNano), "fixedSizeBinaryValue" | "fixed_size_binary_value" => Ok(GeneratedField::FixedSizeBinaryValue), + "unionValue" | "union_value" => Ok(GeneratedField::UnionValue), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -24483,6 +24490,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { return Err(serde::de::Error::duplicate_field("fixedSizeBinaryValue")); } value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::FixedSizeBinaryValue) +; + } + GeneratedField::UnionValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("unionValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::UnionValue) ; } } @@ -26942,6 +26956,117 @@ impl<'de> serde::Deserialize<'de> for UnionExecNode { deserializer.deserialize_struct("datafusion.UnionExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for UnionField { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.field_id != 0 { + len += 1; + } + if self.field.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.UnionField", len)?; + if self.field_id != 0 { + struct_ser.serialize_field("fieldId", &self.field_id)?; + } + if let Some(v) = self.field.as_ref() { + struct_ser.serialize_field("field", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for UnionField { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "field_id", + "fieldId", + "field", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + FieldId, + Field, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "fieldId" | "field_id" => Ok(GeneratedField::FieldId), + "field" => Ok(GeneratedField::Field), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = UnionField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.UnionField") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut field_id__ = None; + let mut field__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::FieldId => { + if field_id__.is_some() { + return Err(serde::de::Error::duplicate_field("fieldId")); + } + field_id__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Field => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("field")); + } + field__ = map_.next_value()?; + } + } + } + Ok(UnionField { + field_id: field_id__.unwrap_or_default(), + field: field__, + }) + } + } + deserializer.deserialize_struct("datafusion.UnionField", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for UnionMode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -27104,6 +27229,153 @@ impl<'de> serde::Deserialize<'de> for UnionNode { deserializer.deserialize_struct("datafusion.UnionNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for UnionValue { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.value_id != 0 { + len += 1; + } + if self.value.is_some() { + len += 1; + } + if !self.fields.is_empty() { + len += 1; + } + if self.mode != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.UnionValue", len)?; + if self.value_id != 0 { + struct_ser.serialize_field("valueId", &self.value_id)?; + } + if let Some(v) = self.value.as_ref() { + struct_ser.serialize_field("value", v)?; + } + if !self.fields.is_empty() { + struct_ser.serialize_field("fields", &self.fields)?; + } + if self.mode != 0 { + let v = UnionMode::try_from(self.mode) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.mode)))?; + struct_ser.serialize_field("mode", &v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for UnionValue { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "value_id", + "valueId", + "value", + "fields", + "mode", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + ValueId, + Value, + Fields, + Mode, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "valueId" | "value_id" => Ok(GeneratedField::ValueId), + "value" => Ok(GeneratedField::Value), + "fields" => Ok(GeneratedField::Fields), + "mode" => Ok(GeneratedField::Mode), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = UnionValue; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.UnionValue") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut value_id__ = None; + let mut value__ = None; + let mut fields__ = None; + let mut mode__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::ValueId => { + if value_id__.is_some() { + return Err(serde::de::Error::duplicate_field("valueId")); + } + value_id__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = map_.next_value()?; + } + GeneratedField::Fields => { + if fields__.is_some() { + return Err(serde::de::Error::duplicate_field("fields")); + } + fields__ = Some(map_.next_value()?); + } + GeneratedField::Mode => { + if mode__.is_some() { + return Err(serde::de::Error::duplicate_field("mode")); + } + mode__ = Some(map_.next_value::()? as i32); + } + } + } + Ok(UnionValue { + value_id: value_id__.unwrap_or_default(), + value: value__, + fields: fields__.unwrap_or_default(), + mode: mode__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.UnionValue", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for UniqueConstraint { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 16ad2b848db9..ab0ddb14ebfc 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1225,6 +1225,28 @@ pub struct IntervalMonthDayNanoValue { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct UnionField { + #[prost(int32, tag = "1")] + pub field_id: i32, + #[prost(message, optional, tag = "2")] + pub field: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct UnionValue { + /// Note that a null union value must have one or more fields, so we + /// encode a null UnionValue as one with value_id == 128 + #[prost(int32, tag = "1")] + pub value_id: i32, + #[prost(message, optional, boxed, tag = "2")] + pub value: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "3")] + pub fields: ::prost::alloc::vec::Vec, + #[prost(enumeration = "UnionMode", tag = "4")] + pub mode: i32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarFixedSizeBinary { #[prost(bytes = "vec", tag = "1")] pub values: ::prost::alloc::vec::Vec, @@ -1236,7 +1258,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 34" + tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 34, 42" )] pub value: ::core::option::Option, } @@ -1320,6 +1342,8 @@ pub mod scalar_value { IntervalMonthDayNano(super::IntervalMonthDayNanoValue), #[prost(message, tag = "34")] FixedSizeBinaryValue(super::ScalarFixedSizeBinary), + #[prost(message, tag = "42")] + UnionValue(::prost::alloc::boxed::Box), } } #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 06aab16edd57..8581156e2bb8 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -768,6 +768,41 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::IntervalMonthDayNano(v) => Self::IntervalMonthDayNano(Some( IntervalMonthDayNanoType::make_value(v.months, v.days, v.nanos), )), + Value::UnionValue(val) => { + let mode = match val.mode { + 0 => UnionMode::Sparse, + 1 => UnionMode::Dense, + id => Err(Error::unknown("UnionMode", id))?, + }; + let ids = val + .fields + .iter() + .map(|f| f.field_id as i8) + .collect::>(); + let fields = val + .fields + .iter() + .map(|f| f.field.clone()) + .collect::>>(); + let fields = fields.ok_or_else(|| Error::required("UnionField"))?; + let fields = fields + .iter() + .map(Field::try_from) + .collect::, _>>()?; + let fields = UnionFields::new(ids, fields); + let v_id = val.value_id as i8; + let val = match &val.value { + None => None, + Some(val) => { + let val: ScalarValue = val + .as_ref() + .try_into() + .map_err(|_| Error::General("Invalid Scalar".to_string()))?; + Some((v_id, Box::new(val))) + } + }; + Self::Union(val, fields, mode) + } Value::FixedSizeBinaryValue(v) => { Self::FixedSizeBinary(v.length, Some(v.clone().values)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 92015594906b..05a29ff6d42b 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -30,6 +30,7 @@ use crate::protobuf::{ }, AnalyzedLogicalPlanType, CubeNode, EmptyMessage, GroupingSetNode, LogicalExprList, OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, + UnionField, UnionValue, }; use arrow::{ @@ -1405,6 +1406,34 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }; Ok(protobuf::ScalarValue { value: Some(value) }) } + + ScalarValue::Union(val, df_fields, mode) => { + let mut fields = Vec::::with_capacity(df_fields.len()); + for (id, field) in df_fields.iter() { + let field_id = id as i32; + let field = Some(field.as_ref().try_into()?); + let field = UnionField { field_id, field }; + fields.push(field); + } + let mode = match mode { + UnionMode::Sparse => 0, + UnionMode::Dense => 1, + }; + let value = match val { + None => None, + Some((_id, v)) => Some(Box::new(v.as_ref().try_into()?)), + }; + let val = UnionValue { + value_id: val.as_ref().map(|(id, _v)| *id as i32).unwrap_or(0), + value, + fields, + mode, + }; + let val = Value::UnionValue(Box::new(val)); + let val = protobuf::ScalarValue { value: Some(val) }; + Ok(val) + } + ScalarValue::Dictionary(index_type, val) => { let value: protobuf::ScalarValue = val.as_ref().try_into()?; Ok(protobuf::ScalarValue { diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index c26e8481ce43..43f3e348dc32 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -456,6 +456,7 @@ impl Unparser<'_> { Ok(ast::Expr::Value(ast::Value::Null)) } ScalarValue::Struct(_) => not_impl_err!("Unsupported scalar: {v:?}"), + ScalarValue::Union(..) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Dictionary(..) => not_impl_err!("Unsupported scalar: {v:?}"), } } From ad8d552b9f150c3c066b0764e84f72b667a649ff Mon Sep 17 00:00:00 2001 From: Val Lorentz Date: Tue, 19 Mar 2024 22:09:20 +0100 Subject: [PATCH 20/35] parquet: Add support for row group pruning on FixedSizeBinary (#9646) * Add support for row group pruning on FixedSizeBinary * Check statistics values are valid for their type --- .../physical_plan/parquet/row_groups.rs | 1 + .../physical_plan/parquet/statistics.rs | 27 ++++- .../core/tests/parquet/row_group_pruning.rs | 101 ++++++++++++++++++ 3 files changed, 127 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 9cd46994960f..a82c5d97a2b7 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -226,6 +226,7 @@ impl PruningStatistics for BloomFilterStatistics { match value { ScalarValue::Utf8(Some(v)) => sbbf.check(&v.as_str()), ScalarValue::Binary(Some(v)) => sbbf.check(v), + ScalarValue::FixedSizeBinary(_size, Some(v)) => sbbf.check(v), ScalarValue::Boolean(Some(v)) => sbbf.check(v), ScalarValue::Float64(Some(v)) => sbbf.check(v), ScalarValue::Float32(Some(v)) => sbbf.check(v), diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs index 4e472606da51..aac5aff80f16 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -105,14 +105,20 @@ macro_rules! get_statistic { let s = std::str::from_utf8(s.$bytes_func()) .map(|s| s.to_string()) .ok(); + if s.is_none() { + log::debug!( + "Utf8 statistics is a non-UTF8 value, ignoring it." + ); + } Some(ScalarValue::Utf8(s)) } } } - // type not supported yet + // type not fully supported yet ParquetStatistics::FixedLenByteArray(s) => { match $target_arrow_type { - // just support the decimal data type + // just support specific logical data types, there are others each + // with their own ordering Some(DataType::Decimal128(precision, scale)) => { Some(ScalarValue::Decimal128( Some(from_bytes_to_i128(s.$bytes_func())), @@ -120,6 +126,23 @@ macro_rules! get_statistic { *scale, )) } + Some(DataType::FixedSizeBinary(size)) => { + let value = s.$bytes_func().to_vec(); + let value = if value.len().try_into() == Ok(*size) { + Some(value) + } else { + log::debug!( + "FixedSizeBinary({}) statistics is a binary of size {}, ignoring it.", + size, + value.len(), + ); + None + }; + Some(ScalarValue::FixedSizeBinary( + *size, + value, + )) + } _ => None, } } diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index 55112193502d..ed48d040648c 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -948,6 +948,107 @@ async fn prune_binary_lt() { .await; } +#[tokio::test] +async fn prune_fixedsizebinary_eq_match() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_fixedsize FROM t WHERE service_fixedsize = ARROW_CAST(CAST('fe6' AS bytea), 'FixedSizeBinary(3)')", + ) + .with_expected_errors(Some(0)) + // false positive on 'all frontends' batch: 'fe1' < 'fe6' < 'fe7' + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(1)) + .with_expected_rows(1) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_fixedsize FROM t WHERE service_fixedsize = ARROW_CAST(CAST('fe6' AS bytea), 'FixedSizeBinary(3)')", + ) + .with_expected_errors(Some(0)) + // false positive on 'all frontends' batch: 'fe1' < 'fe6' < 'fe7' + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(1)) + .with_expected_rows(1) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_fixedsizebinary_eq_no_match() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_fixedsize FROM t WHERE service_fixedsize = ARROW_CAST(CAST('be9' AS bytea), 'FixedSizeBinary(3)')", + ) + .with_expected_errors(Some(0)) + // false positive on 'mixed' batch: 'be1' < 'be9' < 'fe4' + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(1)) + .with_expected_rows(0) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_fixedsizebinary_neq() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_fixedsize FROM t WHERE service_fixedsize != ARROW_CAST(CAST('be1' AS bytea), 'FixedSizeBinary(3)')", + ) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(3)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(14) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_fixedsizebinary_lt() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_fixedsize FROM t WHERE service_fixedsize < ARROW_CAST(CAST('be3' AS bytea), 'FixedSizeBinary(3)')", + ) + .with_expected_errors(Some(0)) + // matches 'all backends' only + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(2) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_fixedsize FROM t WHERE service_fixedsize < ARROW_CAST(CAST('be9' AS bytea), 'FixedSizeBinary(3)')", + ) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + // all backends from 'mixed' and 'all backends' + .with_expected_rows(8) + .test_row_group_prune() + .await; +} + #[tokio::test] async fn prune_periods_in_column_names() { // There are three row groups for "service.name", each with 5 rows = 15 rows total From 89efc4a7e06bd0295ca72dd6ec5fe987d1ac246b Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 20 Mar 2024 08:58:02 -0400 Subject: [PATCH 21/35] Minor: Add documentation about LogicalPlan::expressions (#9698) --- datafusion/expr/src/logical_plan/extension.rs | 9 +++++---- datafusion/expr/src/logical_plan/plan.rs | 14 +++++++++++--- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/datafusion/expr/src/logical_plan/extension.rs b/datafusion/expr/src/logical_plan/extension.rs index f87ca45f14be..bb2c932ce391 100644 --- a/datafusion/expr/src/logical_plan/extension.rs +++ b/datafusion/expr/src/logical_plan/extension.rs @@ -53,10 +53,11 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync { /// Return the output schema of this logical plan node. fn schema(&self) -> &DFSchemaRef; - /// Returns all expressions in the current logical plan node. This - /// should not include expressions of any inputs (aka - /// non-recursively). These expressions are used for optimizer - /// passes and rewrites. + /// Returns all expressions in the current logical plan node. This should + /// not include expressions of any inputs (aka non-recursively). + /// + /// These expressions are used for optimizer + /// passes and rewrites. See [`LogicalPlan::expressions`] for more details. fn expressions(&self) -> Vec; /// A list of output columns (e.g. the names of columns in diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 08fe3380061f..05d7ac539458 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -234,9 +234,17 @@ impl LogicalPlan { ]) } - /// returns all expressions (non-recursively) in the current - /// logical plan node. This does not include expressions in any - /// children + /// Returns all expressions (non-recursively) evaluated by the current + /// logical plan node. This does not include expressions in any children + /// + /// The returned expressions do not necessarily represent or even + /// contributed to the output schema of this node. For example, + /// `LogicalPlan::Filter` returns the filter expression even though the + /// output of a Filter has the same columns as the input. + /// + /// The expressions do contain all the columns that are used by this plan, + /// so if there are columns not referenced by these expressions then + /// DataFusion's optimizer attempts to optimize them away. pub fn expressions(self: &LogicalPlan) -> Vec { let mut exprs = vec![]; self.inspect_expressions(|e| { From 1d0171ab9d33fc7896861dee85804d7daf0a6390 Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 20 Mar 2024 08:24:33 -0700 Subject: [PATCH 22/35] Make builtin window function output datatype to be derived from schema (#9686) * Make builtin window function output datatype to be derived from schema --- datafusion/core/src/physical_planner.rs | 22 ++++----- .../core/tests/fuzz_cases/window_fuzz.rs | 39 +++++++++++++-- datafusion/physical-plan/src/windows/mod.rs | 47 ++++++++++--------- 3 files changed, 72 insertions(+), 36 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index ee581ca64214..ca708b05823e 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -742,13 +742,13 @@ impl DefaultPhysicalPlanner { ); } - let logical_input_schema = input.schema(); + let logical_schema = logical_plan.schema(); let window_expr = window_expr .iter() .map(|e| { create_window_expr( e, - logical_input_schema, + logical_schema, session_state.execution_props(), ) }) @@ -1578,11 +1578,11 @@ pub fn is_window_frame_bound_valid(window_frame: &WindowFrame) -> bool { pub fn create_window_expr_with_name( e: &Expr, name: impl Into, - logical_input_schema: &DFSchema, + logical_schema: &DFSchema, execution_props: &ExecutionProps, ) -> Result> { let name = name.into(); - let physical_input_schema: &Schema = &logical_input_schema.into(); + let physical_schema: &Schema = &logical_schema.into(); match e { Expr::WindowFunction(WindowFunction { fun, @@ -1594,17 +1594,15 @@ pub fn create_window_expr_with_name( }) => { let args = args .iter() - .map(|e| create_physical_expr(e, logical_input_schema, execution_props)) + .map(|e| create_physical_expr(e, logical_schema, execution_props)) .collect::>>()?; let partition_by = partition_by .iter() - .map(|e| create_physical_expr(e, logical_input_schema, execution_props)) + .map(|e| create_physical_expr(e, logical_schema, execution_props)) .collect::>>()?; let order_by = order_by .iter() - .map(|e| { - create_physical_sort_expr(e, logical_input_schema, execution_props) - }) + .map(|e| create_physical_sort_expr(e, logical_schema, execution_props)) .collect::>>()?; if !is_window_frame_bound_valid(window_frame) { @@ -1625,7 +1623,7 @@ pub fn create_window_expr_with_name( &partition_by, &order_by, window_frame, - physical_input_schema, + physical_schema, ignore_nulls, ) } @@ -1636,7 +1634,7 @@ pub fn create_window_expr_with_name( /// Create a window expression from a logical expression or an alias pub fn create_window_expr( e: &Expr, - logical_input_schema: &DFSchema, + logical_schema: &DFSchema, execution_props: &ExecutionProps, ) -> Result> { // unpack aliased logical expressions, e.g. "sum(col) over () as total" @@ -1644,7 +1642,7 @@ pub fn create_window_expr( Expr::Alias(Alias { expr, name, .. }) => (name.clone(), expr.as_ref()), _ => (e.display_name()?, e), }; - create_window_expr_with_name(e, name, logical_input_schema, execution_props) + create_window_expr_with_name(e, name, logical_schema, execution_props) } type AggregateExprWithOptionalArgs = ( diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 00c65995a5ff..2514324a9541 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -22,6 +22,7 @@ use arrow::compute::{concat_batches, SortOptions}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; +use arrow_schema::{Field, Schema}; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::windows::{ @@ -39,6 +40,7 @@ use datafusion_expr::{ }; use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use itertools::Itertools; use test_utils::add_empty_batches; use hashbrown::HashMap; @@ -273,6 +275,9 @@ async fn bounded_window_causal_non_causal() -> Result<()> { window_frame.is_causal() }; + let extended_schema = + schema_add_window_fields(&args, &schema, &window_fn, fn_name)?; + let window_expr = create_window_expr( &window_fn, fn_name.to_string(), @@ -280,7 +285,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { &partitionby_exprs, &orderby_exprs, Arc::new(window_frame), - schema.as_ref(), + &extended_schema, false, )?; let running_window_exec = Arc::new(BoundedWindowAggExec::try_new( @@ -678,6 +683,8 @@ async fn run_window_test( exec1 = Arc::new(SortExec::new(sort_keys, exec1)) as _; } + let extended_schema = schema_add_window_fields(&args, &schema, &window_fn, &fn_name)?; + let usual_window_exec = Arc::new(WindowAggExec::try_new( vec![create_window_expr( &window_fn, @@ -686,7 +693,7 @@ async fn run_window_test( &partitionby_exprs, &orderby_exprs, Arc::new(window_frame.clone()), - schema.as_ref(), + &extended_schema, false, )?], exec1, @@ -704,7 +711,7 @@ async fn run_window_test( &partitionby_exprs, &orderby_exprs, Arc::new(window_frame.clone()), - schema.as_ref(), + &extended_schema, false, )?], exec2, @@ -747,6 +754,32 @@ async fn run_window_test( Ok(()) } +// The planner has fully updated schema before calling the `create_window_expr` +// Replicate the same for this test +fn schema_add_window_fields( + args: &[Arc], + schema: &Arc, + window_fn: &WindowFunctionDefinition, + fn_name: &str, +) -> Result> { + let data_types = args + .iter() + .map(|e| e.clone().as_ref().data_type(schema)) + .collect::>>()?; + let window_expr_return_type = window_fn.return_type(&data_types)?; + let mut window_fields = schema + .fields() + .iter() + .map(|f| f.as_ref().clone()) + .collect_vec(); + window_fields.extend_from_slice(&[Field::new( + fn_name, + window_expr_return_type, + true, + )]); + Ok(Arc::new(Schema::new(window_fields))) +} + /// Return randomly sized record batches with: /// three sorted int32 columns 'a', 'b', 'c' ranged from 0..DISTINCT as columns /// one random int32 column x diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index da2b24487d02..21f42f41fb5c 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -174,20 +174,15 @@ fn create_built_in_window_expr( name: String, ignore_nulls: bool, ) -> Result> { - // need to get the types into an owned vec for some reason - let input_types: Vec<_> = args - .iter() - .map(|arg| arg.data_type(input_schema)) - .collect::>()?; + // derive the output datatype from incoming schema + let out_data_type: &DataType = input_schema.field_with_name(&name)?.data_type(); - // figure out the output type - let data_type = &fun.return_type(&input_types)?; Ok(match fun { - BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name, data_type)), - BuiltInWindowFunction::Rank => Arc::new(rank(name, data_type)), - BuiltInWindowFunction::DenseRank => Arc::new(dense_rank(name, data_type)), - BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name, data_type)), - BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name, data_type)), + BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name, out_data_type)), + BuiltInWindowFunction::Rank => Arc::new(rank(name, out_data_type)), + BuiltInWindowFunction::DenseRank => Arc::new(dense_rank(name, out_data_type)), + BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name, out_data_type)), + BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name, out_data_type)), BuiltInWindowFunction::Ntile => { let n = get_scalar_value_from_args(args, 0)?.ok_or_else(|| { DataFusionError::Execution( @@ -201,13 +196,13 @@ fn create_built_in_window_expr( if n.is_unsigned() { let n: u64 = n.try_into()?; - Arc::new(Ntile::new(name, n, data_type)) + Arc::new(Ntile::new(name, n, out_data_type)) } else { let n: i64 = n.try_into()?; if n <= 0 { return exec_err!("NTILE requires a positive integer"); } - Arc::new(Ntile::new(name, n as u64, data_type)) + Arc::new(Ntile::new(name, n as u64, out_data_type)) } } BuiltInWindowFunction::Lag => { @@ -216,10 +211,10 @@ fn create_built_in_window_expr( .map(|v| v.try_into()) .and_then(|v| v.ok()); let default_value = - get_casted_value(get_scalar_value_from_args(args, 2)?, data_type)?; + get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?; Arc::new(lag( name, - data_type.clone(), + out_data_type.clone(), arg, shift_offset, default_value, @@ -232,10 +227,10 @@ fn create_built_in_window_expr( .map(|v| v.try_into()) .and_then(|v| v.ok()); let default_value = - get_casted_value(get_scalar_value_from_args(args, 2)?, data_type)?; + get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?; Arc::new(lead( name, - data_type.clone(), + out_data_type.clone(), arg, shift_offset, default_value, @@ -252,18 +247,28 @@ fn create_built_in_window_expr( Arc::new(NthValue::nth( name, arg, - data_type.clone(), + out_data_type.clone(), n, ignore_nulls, )?) } BuiltInWindowFunction::FirstValue => { let arg = args[0].clone(); - Arc::new(NthValue::first(name, arg, data_type.clone(), ignore_nulls)) + Arc::new(NthValue::first( + name, + arg, + out_data_type.clone(), + ignore_nulls, + )) } BuiltInWindowFunction::LastValue => { let arg = args[0].clone(); - Arc::new(NthValue::last(name, arg, data_type.clone(), ignore_nulls)) + Arc::new(NthValue::last( + name, + arg, + out_data_type.clone(), + ignore_nulls, + )) } }) } From 3bf06d3cc40657d38ab3425dca1945e4592d2d05 Mon Sep 17 00:00:00 2001 From: Eren Avsarogullari Date: Wed, 20 Mar 2024 12:19:33 -0700 Subject: [PATCH 23/35] Issue-9660 - Extract array_to_string and string_to_array from kernels and udf containers (#9704) --- datafusion/functions-array/src/kernels.rs | 329 +-------------- datafusion/functions-array/src/lib.rs | 9 +- datafusion/functions-array/src/string.rs | 479 ++++++++++++++++++++++ datafusion/functions-array/src/udf.rs | 137 +------ datafusion/functions-array/src/utils.rs | 12 + 5 files changed, 502 insertions(+), 464 deletions(-) create mode 100644 datafusion/functions-array/src/string.rs diff --git a/datafusion/functions-array/src/kernels.rs b/datafusion/functions-array/src/kernels.rs index 15cdf8f279ae..ec0942837795 100644 --- a/datafusion/functions-array/src/kernels.rs +++ b/datafusion/functions-array/src/kernels.rs @@ -18,10 +18,8 @@ //! implementation kernels for array functions use arrow::array::{ - Array, ArrayRef, BooleanArray, Capacities, Date32Array, Float32Array, Float64Array, - GenericListArray, Int16Array, Int32Array, Int64Array, Int8Array, LargeListArray, - LargeStringArray, ListArray, ListBuilder, MutableArrayData, OffsetSizeTrait, - StringArray, StringBuilder, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + Array, ArrayRef, BooleanArray, Capacities, Date32Array, GenericListArray, Int64Array, + LargeListArray, ListArray, MutableArrayData, OffsetSizeTrait, UInt64Array, }; use arrow::compute; use arrow::datatypes::{ @@ -33,335 +31,18 @@ use arrow_schema::FieldRef; use arrow_schema::SortOptions; use datafusion_common::cast::{ - as_date32_array, as_generic_list_array, as_generic_string_array, as_int64_array, - as_interval_mdn_array, as_large_list_array, as_list_array, as_null_array, - as_string_array, + as_date32_array, as_generic_list_array, as_int64_array, as_interval_mdn_array, + as_large_list_array, as_list_array, as_null_array, as_string_array, }; use datafusion_common::{ exec_err, internal_datafusion_err, not_impl_datafusion_err, DataFusionError, Result, ScalarValue, }; +use crate::utils::downcast_arg; use std::any::type_name; use std::sync::Arc; -macro_rules! downcast_arg { - ($ARG:expr, $ARRAY_TYPE:ident) => {{ - $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { - DataFusionError::Internal(format!( - "could not cast to {}", - type_name::<$ARRAY_TYPE>() - )) - })? - }}; -} - -macro_rules! to_string { - ($ARG:expr, $ARRAY:expr, $DELIMITER:expr, $NULL_STRING:expr, $WITH_NULL_STRING:expr, $ARRAY_TYPE:ident) => {{ - let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); - for x in arr { - match x { - Some(x) => { - $ARG.push_str(&x.to_string()); - $ARG.push_str($DELIMITER); - } - None => { - if $WITH_NULL_STRING { - $ARG.push_str($NULL_STRING); - $ARG.push_str($DELIMITER); - } - } - } - } - Ok($ARG) - }}; -} - -macro_rules! call_array_function { - ($DATATYPE:expr, false) => { - match $DATATYPE { - DataType::Utf8 => array_function!(StringArray), - DataType::LargeUtf8 => array_function!(LargeStringArray), - DataType::Boolean => array_function!(BooleanArray), - DataType::Float32 => array_function!(Float32Array), - DataType::Float64 => array_function!(Float64Array), - DataType::Int8 => array_function!(Int8Array), - DataType::Int16 => array_function!(Int16Array), - DataType::Int32 => array_function!(Int32Array), - DataType::Int64 => array_function!(Int64Array), - DataType::UInt8 => array_function!(UInt8Array), - DataType::UInt16 => array_function!(UInt16Array), - DataType::UInt32 => array_function!(UInt32Array), - DataType::UInt64 => array_function!(UInt64Array), - _ => unreachable!(), - } - }; - ($DATATYPE:expr, $INCLUDE_LIST:expr) => {{ - match $DATATYPE { - DataType::List(_) => array_function!(ListArray), - DataType::Utf8 => array_function!(StringArray), - DataType::LargeUtf8 => array_function!(LargeStringArray), - DataType::Boolean => array_function!(BooleanArray), - DataType::Float32 => array_function!(Float32Array), - DataType::Float64 => array_function!(Float64Array), - DataType::Int8 => array_function!(Int8Array), - DataType::Int16 => array_function!(Int16Array), - DataType::Int32 => array_function!(Int32Array), - DataType::Int64 => array_function!(Int64Array), - DataType::UInt8 => array_function!(UInt8Array), - DataType::UInt16 => array_function!(UInt16Array), - DataType::UInt32 => array_function!(UInt32Array), - DataType::UInt64 => array_function!(UInt64Array), - _ => unreachable!(), - } - }}; -} - -/// Array_to_string SQL function -pub(super) fn array_to_string(args: &[ArrayRef]) -> Result { - if args.len() < 2 || args.len() > 3 { - return exec_err!("array_to_string expects two or three arguments"); - } - - let arr = &args[0]; - - let delimiters = as_string_array(&args[1])?; - let delimiters: Vec> = delimiters.iter().collect(); - - let mut null_string = String::from(""); - let mut with_null_string = false; - if args.len() == 3 { - null_string = as_string_array(&args[2])?.value(0).to_string(); - with_null_string = true; - } - - fn compute_array_to_string( - arg: &mut String, - arr: ArrayRef, - delimiter: String, - null_string: String, - with_null_string: bool, - ) -> datafusion_common::Result<&mut String> { - match arr.data_type() { - DataType::List(..) => { - let list_array = as_list_array(&arr)?; - for i in 0..list_array.len() { - compute_array_to_string( - arg, - list_array.value(i), - delimiter.clone(), - null_string.clone(), - with_null_string, - )?; - } - - Ok(arg) - } - DataType::LargeList(..) => { - let list_array = as_large_list_array(&arr)?; - for i in 0..list_array.len() { - compute_array_to_string( - arg, - list_array.value(i), - delimiter.clone(), - null_string.clone(), - with_null_string, - )?; - } - - Ok(arg) - } - DataType::Null => Ok(arg), - data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - to_string!( - arg, - arr, - &delimiter, - &null_string, - with_null_string, - $ARRAY_TYPE - ) - }; - } - call_array_function!(data_type, false) - } - } - } - - fn generate_string_array( - list_arr: &GenericListArray, - delimiters: Vec>, - null_string: String, - with_null_string: bool, - ) -> datafusion_common::Result { - let mut res: Vec> = Vec::new(); - for (arr, &delimiter) in list_arr.iter().zip(delimiters.iter()) { - if let (Some(arr), Some(delimiter)) = (arr, delimiter) { - let mut arg = String::from(""); - let s = compute_array_to_string( - &mut arg, - arr, - delimiter.to_string(), - null_string.clone(), - with_null_string, - )? - .clone(); - - if let Some(s) = s.strip_suffix(delimiter) { - res.push(Some(s.to_string())); - } else { - res.push(Some(s)); - } - } else { - res.push(None); - } - } - - Ok(StringArray::from(res)) - } - - let arr_type = arr.data_type(); - let string_arr = match arr_type { - DataType::List(_) | DataType::FixedSizeList(_, _) => { - let list_array = as_list_array(&arr)?; - generate_string_array::( - list_array, - delimiters, - null_string, - with_null_string, - )? - } - DataType::LargeList(_) => { - let list_array = as_large_list_array(&arr)?; - generate_string_array::( - list_array, - delimiters, - null_string, - with_null_string, - )? - } - _ => { - let mut arg = String::from(""); - let mut res: Vec> = Vec::new(); - // delimiter length is 1 - assert_eq!(delimiters.len(), 1); - let delimiter = delimiters[0].unwrap(); - let s = compute_array_to_string( - &mut arg, - arr.clone(), - delimiter.to_string(), - null_string, - with_null_string, - )? - .clone(); - - if !s.is_empty() { - let s = s.strip_suffix(delimiter).unwrap().to_string(); - res.push(Some(s)); - } else { - res.push(Some(s)); - } - StringArray::from(res) - } - }; - - Ok(Arc::new(string_arr)) -} - -/// Splits string at occurrences of delimiter and returns an array of parts -/// string_to_array('abc~@~def~@~ghi', '~@~') = '["abc", "def", "ghi"]' -pub fn string_to_array(args: &[ArrayRef]) -> Result { - if args.len() < 2 || args.len() > 3 { - return exec_err!("string_to_array expects two or three arguments"); - } - let string_array = as_generic_string_array::(&args[0])?; - let delimiter_array = as_generic_string_array::(&args[1])?; - - let mut list_builder = ListBuilder::new(StringBuilder::with_capacity( - string_array.len(), - string_array.get_buffer_memory_size(), - )); - - match args.len() { - 2 => { - string_array.iter().zip(delimiter_array.iter()).for_each( - |(string, delimiter)| { - match (string, delimiter) { - (Some(string), Some("")) => { - list_builder.values().append_value(string); - list_builder.append(true); - } - (Some(string), Some(delimiter)) => { - string.split(delimiter).for_each(|s| { - list_builder.values().append_value(s); - }); - list_builder.append(true); - } - (Some(string), None) => { - string.chars().map(|c| c.to_string()).for_each(|c| { - list_builder.values().append_value(c); - }); - list_builder.append(true); - } - _ => list_builder.append(false), // null value - } - }, - ); - } - - 3 => { - let null_value_array = as_generic_string_array::(&args[2])?; - string_array - .iter() - .zip(delimiter_array.iter()) - .zip(null_value_array.iter()) - .for_each(|((string, delimiter), null_value)| { - match (string, delimiter) { - (Some(string), Some("")) => { - if Some(string) == null_value { - list_builder.values().append_null(); - } else { - list_builder.values().append_value(string); - } - list_builder.append(true); - } - (Some(string), Some(delimiter)) => { - string.split(delimiter).for_each(|s| { - if Some(s) == null_value { - list_builder.values().append_null(); - } else { - list_builder.values().append_value(s); - } - }); - list_builder.append(true); - } - (Some(string), None) => { - string.chars().map(|c| c.to_string()).for_each(|c| { - if Some(c.as_str()) == null_value { - list_builder.values().append_null(); - } else { - list_builder.values().append_value(c); - } - }); - list_builder.append(true); - } - _ => list_builder.append(false), // null value - } - }); - } - _ => { - return exec_err!( - "Expect string_to_array function to take two or three parameters" - ) - } - } - - let list_array = list_builder.finish(); - Ok(Arc::new(list_array) as ArrayRef) -} - /// Generates an array of integers from start to stop with a given step. /// /// This function takes 1 to 3 ArrayRefs as arguments, representing start, stop, and step values. diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-array/src/lib.rs index fb16acdef2bd..f8d85800b3e3 100644 --- a/datafusion/functions-array/src/lib.rs +++ b/datafusion/functions-array/src/lib.rs @@ -39,6 +39,7 @@ mod remove; mod replace; mod rewrite; mod set_ops; +mod string; mod udf; mod utils; @@ -73,6 +74,8 @@ pub mod expr_fn { pub use super::set_ops::array_distinct; pub use super::set_ops::array_intersect; pub use super::set_ops::array_union; + pub use super::string::array_to_string; + pub use super::string::string_to_array; pub use super::udf::array_dims; pub use super::udf::array_empty; pub use super::udf::array_length; @@ -81,19 +84,17 @@ pub mod expr_fn { pub use super::udf::array_resize; pub use super::udf::array_reverse; pub use super::udf::array_sort; - pub use super::udf::array_to_string; pub use super::udf::cardinality; pub use super::udf::flatten; pub use super::udf::gen_series; pub use super::udf::range; - pub use super::udf::string_to_array; } /// Registers all enabled packages with a [`FunctionRegistry`] pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { let functions: Vec> = vec![ - udf::array_to_string_udf(), - udf::string_to_array_udf(), + string::array_to_string_udf(), + string::string_to_array_udf(), udf::range_udf(), udf::gen_series_udf(), udf::array_dims_udf(), diff --git a/datafusion/functions-array/src/string.rs b/datafusion/functions-array/src/string.rs new file mode 100644 index 000000000000..3140866f5ff6 --- /dev/null +++ b/datafusion/functions-array/src/string.rs @@ -0,0 +1,479 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_to_string and string_to_array functions. + +use arrow::array::{ + Array, ArrayRef, BooleanArray, Float32Array, Float64Array, GenericListArray, + Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, ListBuilder, + OffsetSizeTrait, StringArray, StringBuilder, UInt16Array, UInt32Array, UInt64Array, + UInt8Array, +}; +use arrow::datatypes::{DataType, Field}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{Expr, TypeSignature}; + +use datafusion_common::{plan_err, DataFusionError, Result}; + +use std::any::{type_name, Any}; + +use crate::utils::{downcast_arg, make_scalar_function}; +use arrow_schema::DataType::{LargeUtf8, Utf8}; +use datafusion_common::cast::{ + as_generic_string_array, as_large_list_array, as_list_array, as_string_array, +}; +use datafusion_common::exec_err; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use std::sync::Arc; + +macro_rules! to_string { + ($ARG:expr, $ARRAY:expr, $DELIMITER:expr, $NULL_STRING:expr, $WITH_NULL_STRING:expr, $ARRAY_TYPE:ident) => {{ + let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); + for x in arr { + match x { + Some(x) => { + $ARG.push_str(&x.to_string()); + $ARG.push_str($DELIMITER); + } + None => { + if $WITH_NULL_STRING { + $ARG.push_str($NULL_STRING); + $ARG.push_str($DELIMITER); + } + } + } + } + Ok($ARG) + }}; +} + +macro_rules! call_array_function { + ($DATATYPE:expr, false) => { + match $DATATYPE { + DataType::Utf8 => array_function!(StringArray), + DataType::LargeUtf8 => array_function!(LargeStringArray), + DataType::Boolean => array_function!(BooleanArray), + DataType::Float32 => array_function!(Float32Array), + DataType::Float64 => array_function!(Float64Array), + DataType::Int8 => array_function!(Int8Array), + DataType::Int16 => array_function!(Int16Array), + DataType::Int32 => array_function!(Int32Array), + DataType::Int64 => array_function!(Int64Array), + DataType::UInt8 => array_function!(UInt8Array), + DataType::UInt16 => array_function!(UInt16Array), + DataType::UInt32 => array_function!(UInt32Array), + DataType::UInt64 => array_function!(UInt64Array), + _ => unreachable!(), + } + }; + ($DATATYPE:expr, $INCLUDE_LIST:expr) => {{ + match $DATATYPE { + DataType::List(_) => array_function!(ListArray), + DataType::Utf8 => array_function!(StringArray), + DataType::LargeUtf8 => array_function!(LargeStringArray), + DataType::Boolean => array_function!(BooleanArray), + DataType::Float32 => array_function!(Float32Array), + DataType::Float64 => array_function!(Float64Array), + DataType::Int8 => array_function!(Int8Array), + DataType::Int16 => array_function!(Int16Array), + DataType::Int32 => array_function!(Int32Array), + DataType::Int64 => array_function!(Int64Array), + DataType::UInt8 => array_function!(UInt8Array), + DataType::UInt16 => array_function!(UInt16Array), + DataType::UInt32 => array_function!(UInt32Array), + DataType::UInt64 => array_function!(UInt64Array), + _ => unreachable!(), + } + }}; +} + +// Create static instances of ScalarUDFs for each function +make_udf_function!( + ArrayToString, + array_to_string, + array delimiter, // arg name + "converts each element to its text representation.", // doc + array_to_string_udf // internal function name +); +#[derive(Debug)] +pub(super) struct ArrayToString { + signature: Signature, + aliases: Vec, +} + +impl ArrayToString { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + aliases: vec![ + String::from("array_to_string"), + String::from("list_to_string"), + String::from("array_join"), + String::from("list_join"), + ], + } + } +} + +impl ScalarUDFImpl for ArrayToString { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "array_to_string" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + use DataType::*; + Ok(match arg_types[0] { + List(_) | LargeList(_) | FixedSizeList(_, _) => Utf8, + _ => { + return plan_err!("The array_to_string function can only accept List/LargeList/FixedSizeList."); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_to_string_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +make_udf_function!( + StringToArray, + string_to_array, + string delimiter null_string, // arg name + "splits a `string` based on a `delimiter` and returns an array of parts. Any parts matching the optional `null_string` will be replaced with `NULL`", // doc + string_to_array_udf // internal function name +); +#[derive(Debug)] +pub(super) struct StringToArray { + signature: Signature, + aliases: Vec, +} + +impl StringToArray { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Uniform(2, vec![Utf8, LargeUtf8]), + TypeSignature::Uniform(3, vec![Utf8, LargeUtf8]), + ], + Volatility::Immutable, + ), + aliases: vec![ + String::from("string_to_array"), + String::from("string_to_list"), + ], + } + } +} + +impl ScalarUDFImpl for StringToArray { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "string_to_array" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + use DataType::*; + Ok(match arg_types[0] { + Utf8 | LargeUtf8 => { + List(Arc::new(Field::new("item", arg_types[0].clone(), true))) + } + _ => { + return plan_err!( + "The string_to_array function can only accept Utf8 or LargeUtf8." + ); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + Utf8 => make_scalar_function(string_to_array_inner::)(args), + LargeUtf8 => make_scalar_function(string_to_array_inner::)(args), + other => { + exec_err!("unsupported type for string_to_array function as {other}") + } + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Array_to_string SQL function +pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { + if args.len() < 2 || args.len() > 3 { + return exec_err!("array_to_string expects two or three arguments"); + } + + let arr = &args[0]; + + let delimiters = as_string_array(&args[1])?; + let delimiters: Vec> = delimiters.iter().collect(); + + let mut null_string = String::from(""); + let mut with_null_string = false; + if args.len() == 3 { + null_string = as_string_array(&args[2])?.value(0).to_string(); + with_null_string = true; + } + + fn compute_array_to_string( + arg: &mut String, + arr: ArrayRef, + delimiter: String, + null_string: String, + with_null_string: bool, + ) -> Result<&mut String> { + match arr.data_type() { + DataType::List(..) => { + let list_array = as_list_array(&arr)?; + for i in 0..list_array.len() { + compute_array_to_string( + arg, + list_array.value(i), + delimiter.clone(), + null_string.clone(), + with_null_string, + )?; + } + + Ok(arg) + } + DataType::LargeList(..) => { + let list_array = as_large_list_array(&arr)?; + for i in 0..list_array.len() { + compute_array_to_string( + arg, + list_array.value(i), + delimiter.clone(), + null_string.clone(), + with_null_string, + )?; + } + + Ok(arg) + } + DataType::Null => Ok(arg), + data_type => { + macro_rules! array_function { + ($ARRAY_TYPE:ident) => { + to_string!( + arg, + arr, + &delimiter, + &null_string, + with_null_string, + $ARRAY_TYPE + ) + }; + } + call_array_function!(data_type, false) + } + } + } + + fn generate_string_array( + list_arr: &GenericListArray, + delimiters: Vec>, + null_string: String, + with_null_string: bool, + ) -> Result { + let mut res: Vec> = Vec::new(); + for (arr, &delimiter) in list_arr.iter().zip(delimiters.iter()) { + if let (Some(arr), Some(delimiter)) = (arr, delimiter) { + let mut arg = String::from(""); + let s = compute_array_to_string( + &mut arg, + arr, + delimiter.to_string(), + null_string.clone(), + with_null_string, + )? + .clone(); + + if let Some(s) = s.strip_suffix(delimiter) { + res.push(Some(s.to_string())); + } else { + res.push(Some(s)); + } + } else { + res.push(None); + } + } + + Ok(StringArray::from(res)) + } + + let arr_type = arr.data_type(); + let string_arr = match arr_type { + DataType::List(_) | DataType::FixedSizeList(_, _) => { + let list_array = as_list_array(&arr)?; + generate_string_array::( + list_array, + delimiters, + null_string, + with_null_string, + )? + } + DataType::LargeList(_) => { + let list_array = as_large_list_array(&arr)?; + generate_string_array::( + list_array, + delimiters, + null_string, + with_null_string, + )? + } + _ => { + let mut arg = String::from(""); + let mut res: Vec> = Vec::new(); + // delimiter length is 1 + assert_eq!(delimiters.len(), 1); + let delimiter = delimiters[0].unwrap(); + let s = compute_array_to_string( + &mut arg, + arr.clone(), + delimiter.to_string(), + null_string, + with_null_string, + )? + .clone(); + + if !s.is_empty() { + let s = s.strip_suffix(delimiter).unwrap().to_string(); + res.push(Some(s)); + } else { + res.push(Some(s)); + } + StringArray::from(res) + } + }; + + Ok(Arc::new(string_arr)) +} + +/// String_to_array SQL function +/// Splits string at occurrences of delimiter and returns an array of parts +/// string_to_array('abc~@~def~@~ghi', '~@~') = '["abc", "def", "ghi"]' +pub fn string_to_array_inner(args: &[ArrayRef]) -> Result { + if args.len() < 2 || args.len() > 3 { + return exec_err!("string_to_array expects two or three arguments"); + } + let string_array = as_generic_string_array::(&args[0])?; + let delimiter_array = as_generic_string_array::(&args[1])?; + + let mut list_builder = ListBuilder::new(StringBuilder::with_capacity( + string_array.len(), + string_array.get_buffer_memory_size(), + )); + + match args.len() { + 2 => { + string_array.iter().zip(delimiter_array.iter()).for_each( + |(string, delimiter)| { + match (string, delimiter) { + (Some(string), Some("")) => { + list_builder.values().append_value(string); + list_builder.append(true); + } + (Some(string), Some(delimiter)) => { + string.split(delimiter).for_each(|s| { + list_builder.values().append_value(s); + }); + list_builder.append(true); + } + (Some(string), None) => { + string.chars().map(|c| c.to_string()).for_each(|c| { + list_builder.values().append_value(c); + }); + list_builder.append(true); + } + _ => list_builder.append(false), // null value + } + }, + ); + } + + 3 => { + let null_value_array = as_generic_string_array::(&args[2])?; + string_array + .iter() + .zip(delimiter_array.iter()) + .zip(null_value_array.iter()) + .for_each(|((string, delimiter), null_value)| { + match (string, delimiter) { + (Some(string), Some("")) => { + if Some(string) == null_value { + list_builder.values().append_null(); + } else { + list_builder.values().append_value(string); + } + list_builder.append(true); + } + (Some(string), Some(delimiter)) => { + string.split(delimiter).for_each(|s| { + if Some(s) == null_value { + list_builder.values().append_null(); + } else { + list_builder.values().append_value(s); + } + }); + list_builder.append(true); + } + (Some(string), None) => { + string.chars().map(|c| c.to_string()).for_each(|c| { + if Some(c.as_str()) == null_value { + list_builder.values().append_null(); + } else { + list_builder.values().append_value(c); + } + }); + list_builder.append(true); + } + _ => list_builder.append(false), // null value + } + }); + } + _ => { + return exec_err!( + "Expect string_to_array function to take two or three parameters" + ) + } + } + + let list_array = list_builder.finish(); + Ok(Arc::new(list_array) as ArrayRef) +} diff --git a/datafusion/functions-array/src/udf.rs b/datafusion/functions-array/src/udf.rs index e0793900c6b3..5f5d90851758 100644 --- a/datafusion/functions-array/src/udf.rs +++ b/datafusion/functions-array/src/udf.rs @@ -17,11 +17,10 @@ //! [`ScalarUDFImpl`] definitions for array functions. -use arrow::array::{NullArray, StringArray}; use arrow::datatypes::DataType; use arrow::datatypes::Field; use arrow::datatypes::IntervalUnit::MonthDayNano; -use arrow_schema::DataType::{LargeUtf8, List, Utf8}; +use arrow_schema::DataType::List; use datafusion_common::exec_err; use datafusion_common::plan_err; use datafusion_common::Result; @@ -32,140 +31,6 @@ use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; -// Create static instances of ScalarUDFs for each function -make_udf_function!(ArrayToString, - array_to_string, - array delimiter, // arg name - "converts each element to its text representation.", // doc - array_to_string_udf // internal function name -); -#[derive(Debug)] -pub(super) struct ArrayToString { - signature: Signature, - aliases: Vec, -} - -impl ArrayToString { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec![ - String::from("array_to_string"), - String::from("list_to_string"), - String::from("array_join"), - String::from("list_join"), - ], - } - } -} - -impl ScalarUDFImpl for ArrayToString { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_to_string" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => Utf8, - _ => { - return plan_err!("The array_to_string function can only accept List/LargeList/FixedSizeList."); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_to_string(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!(StringToArray, - string_to_array, - string delimiter null_string, // arg name - "splits a `string` based on a `delimiter` and returns an array of parts. Any parts matching the optional `null_string` will be replaced with `NULL`", // doc - string_to_array_udf // internal function name -); -#[derive(Debug)] -pub(super) struct StringToArray { - signature: Signature, - aliases: Vec, -} - -impl StringToArray { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec![ - String::from("string_to_array"), - String::from("string_to_list"), - ], - } - } -} - -impl ScalarUDFImpl for StringToArray { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "string_to_array" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - Utf8 | LargeUtf8 => { - List(Arc::new(Field::new("item", arg_types[0].clone(), true))) - } - _ => { - return plan_err!( - "The string_to_array function can only accept Utf8 or LargeUtf8." - ); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let mut args = ColumnarValue::values_to_arrays(args)?; - // Case: delimiter is NULL, needs to be handled as well. - if args[1].as_any().is::() { - args[1] = Arc::new(StringArray::new_null(args[1].len())); - }; - - match args[0].data_type() { - Utf8 => { - crate::kernels::string_to_array::(&args).map(ColumnarValue::Array) - } - LargeUtf8 => { - crate::kernels::string_to_array::(&args).map(ColumnarValue::Array) - } - other => { - exec_err!("unsupported type for string_to_array function as {other}") - } - } - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - make_udf_function!( Range, range, diff --git a/datafusion/functions-array/src/utils.rs b/datafusion/functions-array/src/utils.rs index 9589cb05fe9b..c0f7627d2ab7 100644 --- a/datafusion/functions-array/src/utils.rs +++ b/datafusion/functions-array/src/utils.rs @@ -214,6 +214,18 @@ pub(crate) fn compare_element_to_list( Ok(res) } +macro_rules! downcast_arg { + ($ARG:expr, $ARRAY_TYPE:ident) => {{ + $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast to {}", + type_name::<$ARRAY_TYPE>() + )) + })? + }}; +} +pub(crate) use downcast_arg; + #[cfg(test)] mod tests { use super::*; From 55aacf62b39c7632df6536b2c1bf3856faf708ac Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 20 Mar 2024 15:19:42 -0400 Subject: [PATCH 24/35] Document MSRV policy (#9681) --- README.md | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index abd727672aca..c3d7c6792990 100644 --- a/README.md +++ b/README.md @@ -95,6 +95,11 @@ Optional features: [apache avro]: https://avro.apache.org/ [apache parquet]: https://parquet.apache.org/ -## Rust Version Compatibility +## Rust Version Compatibility Policy -Datafusion crate is tested with the [minimum required stable Rust version](https://github.com/search?q=repo%3Aapache%2Farrow-datafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) +DataFusion's Minimum Required Stable Rust Version (MSRV) policy is to support +each stable Rust version for 6 months after it is +[released](https://github.com/rust-lang/rust/blob/master/RELEASES.md). This +generally translates to support for the most recent 3 to 4 stable Rust versions. + +We enforce this policy using a [MSRV CI Check](https://github.com/search?q=repo%3Aapache%2Farrow-datafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) From 496e4b67a05bb49af8d7aa1ca5035312fd4e54f9 Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 20 Mar 2024 13:32:15 -0700 Subject: [PATCH 25/35] doc: Add DataFusion profiling documentation for MacOS (#9711) * Add profiling doc for MacOS --- docs/source/index.rst | 3 +- docs/source/library-user-guide/profiling.md | 63 +++++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) create mode 100644 docs/source/library-user-guide/profiling.md diff --git a/docs/source/index.rst b/docs/source/index.rst index f7c0873f3a5f..919a7ad7036f 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -79,7 +79,7 @@ Please see the `developer’s guide`_ for contributing and `communication`_ for .. toctree:: :maxdepth: 1 :caption: Library User Guide - + library-user-guide/index library-user-guide/using-the-sql-api library-user-guide/working-with-exprs @@ -89,6 +89,7 @@ Please see the `developer’s guide`_ for contributing and `communication`_ for library-user-guide/adding-udfs library-user-guide/custom-table-providers library-user-guide/extending-operators + library-user-guide/profiling .. _toc.contributor-guide: diff --git a/docs/source/library-user-guide/profiling.md b/docs/source/library-user-guide/profiling.md new file mode 100644 index 000000000000..a20489496f0c --- /dev/null +++ b/docs/source/library-user-guide/profiling.md @@ -0,0 +1,63 @@ + + +# Profiling Cookbook + +The section contains examples how to perform CPU profiling for Apache Arrow DataFusion on different operating systems. + +## MacOS + +### Building a flamegraph + +- [cargo-flamegraph](https://github.com/flamegraph-rs/flamegraph) + +Test: + +```bash +CARGO_PROFILE_RELEASE_DEBUG=true cargo flamegraph --root --unit-test datafusion -- dataframe::tests::test_array_agg +``` + +Benchmark: + +```bash +CARGO_PROFILE_RELEASE_DEBUG=true cargo flamegraph --root --bench sql_planner -- --bench +``` + +Open `flamegraph.svg` file with the browser + +- dtrace with DataFusion CLI + +```bash +git clone https://github.com/brendangregg/FlameGraph.git /tmp/fg +cd datafusion-cli +CARGO_PROFILE_RELEASE_DEBUG=true cargo build --release +echo "select * from table;" >> test.sql +sudo dtrace -c './target/debug/datafusion-cli -f test.sql' -o out.stacks -n 'profile-997 /execname == "datafusion-cli"/ { @[ustack(100)] = count(); }' +/tmp/fg/FlameGraph/stackcollapse.pl out.stacks | /tmp/fg/FlameGraph/flamegraph.pl > flamegraph.svg +``` + +Open `flamegraph.svg` file with the browser + +### CPU profiling with XCode Instruments + +[Video: how to CPU profile DataFusion with XCode Instruments](https://youtu.be/P3dXH61Kr5U) + +## Linux + +## Windows From e522bcebb04288fe7fe27192c51dabdf04e6ac88 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 20 Mar 2024 17:13:43 -0400 Subject: [PATCH 26/35] Minor: add ticket reference to commented out test (#9715) --- datafusion/sqllogictest/test_files/copy.slt | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index 4d4f596d0c60..7884bece1f39 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -111,6 +111,7 @@ a statement ok create table test ("'test'" varchar, "'test2'" varchar, "'test3'" varchar); +# https://github.com/apache/arrow-datafusion/issues/9714 ## Until the partition by parsing uses ColumnDef, this test is meaningless since it becomes an overfit. Even in ## CREATE EXTERNAL TABLE, there is a schema mismatch, this should be an issue. # From 7a0dd6ff5a78e10a96cb6ee7e1390b2a2df941b2 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 20 Mar 2024 17:33:59 -0400 Subject: [PATCH 27/35] Minor: Change path from `common_runtime` to `common-runtime` (#9717) --- Cargo.toml | 4 ++-- datafusion/{common_runtime => common-runtime}/Cargo.toml | 0 datafusion/{common_runtime => common-runtime}/README.md | 0 datafusion/{common_runtime => common-runtime}/src/common.rs | 0 datafusion/{common_runtime => common-runtime}/src/lib.rs | 0 5 files changed, 2 insertions(+), 2 deletions(-) rename datafusion/{common_runtime => common-runtime}/Cargo.toml (100%) rename datafusion/{common_runtime => common-runtime}/README.md (100%) rename datafusion/{common_runtime => common-runtime}/src/common.rs (100%) rename datafusion/{common_runtime => common-runtime}/src/lib.rs (100%) diff --git a/Cargo.toml b/Cargo.toml index d9e69e53db7c..abe6d2c1744b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ exclude = ["datafusion-cli"] members = [ "datafusion/common", - "datafusion/common_runtime", + "datafusion/common-runtime", "datafusion/core", "datafusion/expr", "datafusion/execution", @@ -73,7 +73,7 @@ ctor = "0.2.0" dashmap = "5.4.0" datafusion = { path = "datafusion/core", version = "36.0.0", default-features = false } datafusion-common = { path = "datafusion/common", version = "36.0.0", default-features = false } -datafusion-common-runtime = { path = "datafusion/common_runtime", version = "36.0.0" } +datafusion-common-runtime = { path = "datafusion/common-runtime", version = "36.0.0" } datafusion-execution = { path = "datafusion/execution", version = "36.0.0" } datafusion-expr = { path = "datafusion/expr", version = "36.0.0" } datafusion-functions = { path = "datafusion/functions", version = "36.0.0" } diff --git a/datafusion/common_runtime/Cargo.toml b/datafusion/common-runtime/Cargo.toml similarity index 100% rename from datafusion/common_runtime/Cargo.toml rename to datafusion/common-runtime/Cargo.toml diff --git a/datafusion/common_runtime/README.md b/datafusion/common-runtime/README.md similarity index 100% rename from datafusion/common_runtime/README.md rename to datafusion/common-runtime/README.md diff --git a/datafusion/common_runtime/src/common.rs b/datafusion/common-runtime/src/common.rs similarity index 100% rename from datafusion/common_runtime/src/common.rs rename to datafusion/common-runtime/src/common.rs diff --git a/datafusion/common_runtime/src/lib.rs b/datafusion/common-runtime/src/lib.rs similarity index 100% rename from datafusion/common_runtime/src/lib.rs rename to datafusion/common-runtime/src/lib.rs From dbfb153658f17448af8e7de7bab0d37f73cdeac1 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Wed, 20 Mar 2024 16:01:25 -0600 Subject: [PATCH 28/35] Use object_store:BufWriter to replace put_multipart (#9648) * feat: use BufWriter to replace put_multipart * feat: remove AbortableWrite * fix clippy * fix: add doc comment --- Cargo.toml | 2 +- .../file_format/file_compression_type.rs | 7 +- .../src/datasource/file_format/parquet.rs | 19 ++-- .../src/datasource/file_format/write/mod.rs | 100 ++---------------- .../file_format/write/orchestration.rs | 18 +--- .../core/src/datasource/physical_plan/csv.rs | 10 +- .../core/src/datasource/physical_plan/json.rs | 10 +- .../datasource/physical_plan/parquet/mod.rs | 5 +- 8 files changed, 37 insertions(+), 134 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index abe6d2c1744b..c3dade8bc6c5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -93,7 +93,7 @@ indexmap = "2.0.0" itertools = "0.12" log = "^0.4" num_cpus = "1.13.0" -object_store = { version = "0.9.0", default-features = false } +object_store = { version = "0.9.1", default-features = false } parking_lot = "0.12" parquet = { version = "51.0.0", default-features = false, features = ["arrow", "async", "object_store"] } rand = "0.8" diff --git a/datafusion/core/src/datasource/file_format/file_compression_type.rs b/datafusion/core/src/datasource/file_format/file_compression_type.rs index c538819e2684..c1fbe352d37b 100644 --- a/datafusion/core/src/datasource/file_format/file_compression_type.rs +++ b/datafusion/core/src/datasource/file_format/file_compression_type.rs @@ -43,6 +43,7 @@ use futures::stream::BoxStream; use futures::StreamExt; #[cfg(feature = "compression")] use futures::TryStreamExt; +use object_store::buffered::BufWriter; use tokio::io::AsyncWrite; #[cfg(feature = "compression")] use tokio_util::io::{ReaderStream, StreamReader}; @@ -148,11 +149,11 @@ impl FileCompressionType { }) } - /// Wrap the given `AsyncWrite` so that it performs compressed writes + /// Wrap the given `BufWriter` so that it performs compressed writes /// according to this `FileCompressionType`. pub fn convert_async_writer( &self, - w: Box, + w: BufWriter, ) -> Result> { Ok(match self.variant { #[cfg(feature = "compression")] @@ -169,7 +170,7 @@ impl FileCompressionType { "Compression feature is not enabled".to_owned(), )) } - UNCOMPRESSED => w, + UNCOMPRESSED => Box::new(w), }) } diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index b7626d41f4dd..ec333bb557d2 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -23,7 +23,7 @@ use std::fmt::Debug; use std::sync::Arc; use super::write::demux::start_demuxer_task; -use super::write::{create_writer, AbortableWrite, SharedBuffer}; +use super::write::{create_writer, SharedBuffer}; use super::{FileFormat, FileScanConfig}; use crate::arrow::array::{ BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, RecordBatch, @@ -56,6 +56,7 @@ use datafusion_physical_plan::metrics::MetricsSet; use async_trait::async_trait; use bytes::{BufMut, BytesMut}; +use object_store::buffered::BufWriter; use parquet::arrow::arrow_writer::{ compute_leaves, get_column_writers, ArrowColumnChunk, ArrowColumnWriter, ArrowLeafColumn, @@ -613,19 +614,13 @@ impl ParquetSink { location: &Path, object_store: Arc, parquet_props: WriterProperties, - ) -> Result< - AsyncArrowWriter>, - > { - let (_, multipart_writer) = object_store - .put_multipart(location) - .await - .map_err(DataFusionError::ObjectStore)?; + ) -> Result> { + let buf_writer = BufWriter::new(object_store, location.clone()); let writer = AsyncArrowWriter::try_new( - multipart_writer, + buf_writer, self.get_writer_schema(), Some(parquet_props), )?; - Ok(writer) } @@ -943,7 +938,7 @@ async fn concatenate_parallel_row_groups( mut serialize_rx: Receiver>, schema: Arc, writer_props: Arc, - mut object_store_writer: AbortableWrite>, + mut object_store_writer: Box, ) -> Result { let merged_buff = SharedBuffer::new(INITIAL_BUFFER_BYTES); @@ -985,7 +980,7 @@ async fn concatenate_parallel_row_groups( /// task then stitches these independent RowGroups together and streams this large /// single parquet file to an ObjectStore in multiple parts. async fn output_single_parquet_file_parallelized( - object_store_writer: AbortableWrite>, + object_store_writer: Box, data: Receiver, output_schema: Arc, parquet_props: &WriterProperties, diff --git a/datafusion/core/src/datasource/file_format/write/mod.rs b/datafusion/core/src/datasource/file_format/write/mod.rs index 410a32a19cc1..42115fc7b93f 100644 --- a/datafusion/core/src/datasource/file_format/write/mod.rs +++ b/datafusion/core/src/datasource/file_format/write/mod.rs @@ -18,21 +18,18 @@ //! Module containing helper methods/traits related to enabling //! write support for the various file formats -use std::io::{Error, Write}; -use std::pin::Pin; +use std::io::Write; use std::sync::Arc; -use std::task::{Context, Poll}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::error::Result; use arrow_array::RecordBatch; -use datafusion_common::DataFusionError; use bytes::Bytes; -use futures::future::BoxFuture; +use object_store::buffered::BufWriter; use object_store::path::Path; -use object_store::{MultipartId, ObjectStore}; +use object_store::ObjectStore; use tokio::io::AsyncWrite; pub(crate) mod demux; @@ -69,79 +66,6 @@ impl Write for SharedBuffer { } } -/// Stores data needed during abortion of MultiPart writers -#[derive(Clone)] -pub(crate) struct MultiPart { - /// A shared reference to the object store - store: Arc, - multipart_id: MultipartId, - location: Path, -} - -impl MultiPart { - /// Create a new `MultiPart` - pub fn new( - store: Arc, - multipart_id: MultipartId, - location: Path, - ) -> Self { - Self { - store, - multipart_id, - location, - } - } -} - -/// A wrapper struct with abort method and writer -pub(crate) struct AbortableWrite { - writer: W, - multipart: MultiPart, -} - -impl AbortableWrite { - /// Create a new `AbortableWrite` instance with the given writer, and write mode. - pub(crate) fn new(writer: W, multipart: MultiPart) -> Self { - Self { writer, multipart } - } - - /// handling of abort for different write modes - pub(crate) fn abort_writer(&self) -> Result>> { - let multi = self.multipart.clone(); - Ok(Box::pin(async move { - multi - .store - .abort_multipart(&multi.location, &multi.multipart_id) - .await - .map_err(DataFusionError::ObjectStore) - })) - } -} - -impl AsyncWrite for AbortableWrite { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.get_mut().writer).poll_write(cx, buf) - } - - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(&mut self.get_mut().writer).poll_flush(cx) - } - - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(&mut self.get_mut().writer).poll_shutdown(cx) - } -} - /// A trait that defines the methods required for a RecordBatch serializer. pub trait BatchSerializer: Sync + Send { /// Asynchronously serializes a `RecordBatch` and returns the serialized bytes. @@ -150,19 +74,15 @@ pub trait BatchSerializer: Sync + Send { fn serialize(&self, batch: RecordBatch, initial: bool) -> Result; } -/// Returns an [`AbortableWrite`] which writes to the given object store location -/// with the specified compression +/// Returns an [`AsyncWrite`] which writes to the given object store location +/// with the specified compression. +/// We drop the `AbortableWrite` struct and the writer will not try to cleanup on failure. +/// Users can configure automatic cleanup with their cloud provider. pub(crate) async fn create_writer( file_compression_type: FileCompressionType, location: &Path, object_store: Arc, -) -> Result>> { - let (multipart_id, writer) = object_store - .put_multipart(location) - .await - .map_err(DataFusionError::ObjectStore)?; - Ok(AbortableWrite::new( - file_compression_type.convert_async_writer(writer)?, - MultiPart::new(object_store, multipart_id, location.clone()), - )) +) -> Result> { + let buf_writer = BufWriter::new(object_store, location.clone()); + file_compression_type.convert_async_writer(buf_writer) } diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs index b7f268959311..3ae2122de827 100644 --- a/datafusion/core/src/datasource/file_format/write/orchestration.rs +++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use super::demux::start_demuxer_task; -use super::{create_writer, AbortableWrite, BatchSerializer}; +use super::{create_writer, BatchSerializer}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::physical_plan::FileSinkConfig; use crate::error::Result; @@ -39,7 +39,7 @@ use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc::{self, Receiver}; use tokio::task::JoinSet; -type WriterType = AbortableWrite>; +type WriterType = Box; type SerializerType = Arc; /// Serializes a single data stream in parallel and writes to an ObjectStore @@ -49,7 +49,7 @@ type SerializerType = Arc; pub(crate) async fn serialize_rb_stream_to_object_store( mut data_rx: Receiver, serializer: Arc, - mut writer: AbortableWrite>, + mut writer: WriterType, ) -> std::result::Result<(WriterType, u64), (WriterType, DataFusionError)> { let (tx, mut rx) = mpsc::channel::>>(100); @@ -173,19 +173,9 @@ pub(crate) async fn stateless_serialize_and_write_files( // Finalize or abort writers as appropriate for mut writer in finished_writers.into_iter() { - match any_errors { - true => { - let abort_result = writer.abort_writer(); - if abort_result.is_err() { - any_abort_errors = true; - } - } - false => { - writer.shutdown() + writer.shutdown() .await .map_err(|_| internal_datafusion_err!("Error encountered while finalizing writes! Partial results may have been written to ObjectStore!"))?; - } - } } if any_errors { diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 5fcb9f483952..31cc52f79697 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -44,6 +44,7 @@ use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use bytes::{Buf, Bytes}; use futures::{ready, StreamExt, TryStreamExt}; +use object_store::buffered::BufWriter; use object_store::{GetOptions, GetResultPayload, ObjectStore}; use tokio::io::AsyncWriteExt; use tokio::task::JoinSet; @@ -471,7 +472,7 @@ pub async fn plan_to_csv( let mut stream = plan.execute(i, task_ctx.clone())?; join_set.spawn(async move { - let (_, mut multipart_writer) = storeref.put_multipart(&file).await?; + let mut buf_writer = BufWriter::new(storeref, file.clone()); let mut buffer = Vec::with_capacity(1024); //only write headers on first iteration let mut write_headers = true; @@ -481,15 +482,12 @@ pub async fn plan_to_csv( .build(buffer); writer.write(&batch)?; buffer = writer.into_inner(); - multipart_writer.write_all(&buffer).await?; + buf_writer.write_all(&buffer).await?; buffer.clear(); //prevent writing headers more than once write_headers = false; } - multipart_writer - .shutdown() - .await - .map_err(DataFusionError::from) + buf_writer.shutdown().await.map_err(DataFusionError::from) }); } diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 068426e0fdcb..194a4a91c34a 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -43,6 +43,7 @@ use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use bytes::{Buf, Bytes}; use futures::{ready, StreamExt, TryStreamExt}; +use object_store::buffered::BufWriter; use object_store::{self, GetOptions, GetResultPayload, ObjectStore}; use tokio::io::AsyncWriteExt; use tokio::task::JoinSet; @@ -338,21 +339,18 @@ pub async fn plan_to_json( let mut stream = plan.execute(i, task_ctx.clone())?; join_set.spawn(async move { - let (_, mut multipart_writer) = storeref.put_multipart(&file).await?; + let mut buf_writer = BufWriter::new(storeref, file.clone()); let mut buffer = Vec::with_capacity(1024); while let Some(batch) = stream.next().await.transpose()? { let mut writer = json::LineDelimitedWriter::new(buffer); writer.write(&batch)?; buffer = writer.into_inner(); - multipart_writer.write_all(&buffer).await?; + buf_writer.write_all(&buffer).await?; buffer.clear(); } - multipart_writer - .shutdown() - .await - .map_err(DataFusionError::from) + buf_writer.shutdown().await.map_err(DataFusionError::from) }); } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 282cd624d036..767cde9cc55e 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -52,6 +52,7 @@ use futures::future::BoxFuture; use futures::{StreamExt, TryStreamExt}; use itertools::Itertools; use log::debug; +use object_store::buffered::BufWriter; use object_store::path::Path; use object_store::ObjectStore; use parquet::arrow::arrow_reader::ArrowReaderOptions; @@ -698,11 +699,11 @@ pub async fn plan_to_parquet( let propclone = writer_properties.clone(); let storeref = store.clone(); - let (_, multipart_writer) = storeref.put_multipart(&file).await?; + let buf_writer = BufWriter::new(storeref, file.clone()); let mut stream = plan.execute(i, task_ctx.clone())?; join_set.spawn(async move { let mut writer = - AsyncArrowWriter::try_new(multipart_writer, plan.schema(), propclone)?; + AsyncArrowWriter::try_new(buf_writer, plan.schema(), propclone)?; while let Some(next_batch) = stream.next().await { let batch = next_batch?; writer.write(&batch).await?; From 14972e6ae4be799450d1fbb81073fa0e1cbe57bc Mon Sep 17 00:00:00 2001 From: Kunal Kundu Date: Thu, 21 Mar 2024 05:24:04 +0530 Subject: [PATCH 29/35] Fix COPY TO failing on passing format options through CLI (#9709) * Fix COPY TO failing on passing format options through CLI * fix clippy lint error --- datafusion-cli/src/exec.rs | 20 +++++++++++++++++-- .../common/src/file_options/file_type.rs | 14 +++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index ea765ee8eceb..4e374a4c0032 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -40,6 +40,7 @@ use datafusion::prelude::SessionContext; use datafusion::sql::parser::{DFParser, Statement}; use datafusion::sql::sqlparser::dialect::dialect_from_str; +use datafusion_common::FileType; use rustyline::error::ReadlineError; use rustyline::Editor; use tokio::signal; @@ -257,15 +258,23 @@ async fn create_plan( // datafusion-cli specific options before passing through to datafusion. Otherwise, datafusion // will raise Configuration errors. if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { - register_object_store_and_config_extensions(ctx, &cmd.location, &cmd.options) - .await?; + register_object_store_and_config_extensions( + ctx, + &cmd.location, + &cmd.options, + None, + ) + .await?; } if let LogicalPlan::Copy(copy_to) = &mut plan { + let format: FileType = (©_to.format_options).into(); + register_object_store_and_config_extensions( ctx, ©_to.output_url, ©_to.options, + Some(format), ) .await?; } @@ -303,6 +312,7 @@ pub(crate) async fn register_object_store_and_config_extensions( ctx: &SessionContext, location: &String, options: &HashMap, + format: Option, ) -> Result<()> { // Parse the location URL to extract the scheme and other components let table_path = ListingTableUrl::parse(location)?; @@ -318,6 +328,9 @@ pub(crate) async fn register_object_store_and_config_extensions( // Clone and modify the default table options based on the provided options let mut table_options = ctx.state().default_table_options().clone(); + if let Some(format) = format { + table_options.set_file_format(format); + } table_options.alter_with_string_hash_map(options)?; // Retrieve the appropriate object store based on the scheme, URL, and modified table options @@ -347,6 +360,7 @@ mod tests { &ctx, &cmd.location, &cmd.options, + None, ) .await?; } else { @@ -367,10 +381,12 @@ mod tests { let plan = ctx.state().create_logical_plan(sql).await?; if let LogicalPlan::Copy(cmd) = &plan { + let format: FileType = (&cmd.format_options).into(); register_object_store_and_config_extensions( &ctx, &cmd.output_url, &cmd.options, + Some(format), ) .await?; } else { diff --git a/datafusion/common/src/file_options/file_type.rs b/datafusion/common/src/file_options/file_type.rs index 812cb02a5f77..fc0bb7445645 100644 --- a/datafusion/common/src/file_options/file_type.rs +++ b/datafusion/common/src/file_options/file_type.rs @@ -20,6 +20,7 @@ use std::fmt::{self, Display}; use std::str::FromStr; +use crate::config::FormatOptions; use crate::error::{DataFusionError, Result}; /// The default file extension of arrow files @@ -55,6 +56,19 @@ pub enum FileType { JSON, } +impl From<&FormatOptions> for FileType { + fn from(value: &FormatOptions) -> Self { + match value { + FormatOptions::CSV(_) => FileType::CSV, + FormatOptions::JSON(_) => FileType::JSON, + #[cfg(feature = "parquet")] + FormatOptions::PARQUET(_) => FileType::PARQUET, + FormatOptions::AVRO => FileType::AVRO, + FormatOptions::ARROW => FileType::ARROW, + } + } +} + impl GetExt for FileType { fn get_ext(&self) -> String { match self { From b72d25cc3a3a4257de1fc88e8df56b4c874d60ce Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Thu, 21 Mar 2024 07:56:54 +0800 Subject: [PATCH 30/35] fix: recursive cte hangs on joins (#9687) * fix: recursive cte hangs on joins * Use ExecutionPlan::with_new_children * Naming --- .../physical-plan/src/recursive_query.rs | 26 ++++++- datafusion/sqllogictest/test_files/cte.slt | 73 +++++++++++++++++-- 2 files changed, 90 insertions(+), 9 deletions(-) diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 68abc9653a8b..140820ff782a 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -309,10 +309,9 @@ impl RecursiveQueryStream { // Downstream plans should not expect any partitioning. let partition = 0; - self.recursive_stream = Some( - self.recursive_term - .execute(partition, self.task_context.clone())?, - ); + let recursive_plan = reset_plan_states(self.recursive_term.clone())?; + self.recursive_stream = + Some(recursive_plan.execute(partition, self.task_context.clone())?); self.poll_next(cx) } } @@ -343,6 +342,25 @@ fn assign_work_table( .data() } +/// Some plans will change their internal states after execution, making them unable to be executed again. +/// This function uses `ExecutionPlan::with_new_children` to fork a new plan with initial states. +/// +/// An example is `CrossJoinExec`, which loads the left table into memory and stores it in the plan. +/// However, if the data of the left table is derived from the work table, it will become outdated +/// as the work table changes. When the next iteration executes this plan again, we must clear the left table. +fn reset_plan_states(plan: Arc) -> Result> { + plan.transform_up(&|plan| { + // WorkTableExec's states have already been updated correctly. + if plan.as_any().is::() { + Ok(Transformed::no(plan)) + } else { + let new_plan = plan.clone().with_new_children(plan.children())?; + Ok(Transformed::yes(new_plan)) + } + }) + .data() +} + impl Stream for RecursiveQueryStream { type Item = Result; diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index 6b9db5589391..50c88e41959f 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -40,11 +40,6 @@ ProjectionExec: expr=[1 as a, 2 as b, 3 as c] --PlaceholderRowExec - -# enable recursive CTEs -statement ok -set datafusion.execution.enable_recursive_ctes = true; - # trivial recursive CTE works query I rowsort WITH RECURSIVE nodes AS ( @@ -651,3 +646,71 @@ WITH RECURSIVE my_cte AS ( WHERE my_cte.a<5 ) SELECT a FROM my_cte; + + +# Test issue: https://github.com/apache/arrow-datafusion/issues/9680 +query I +WITH RECURSIVE recursive_cte AS ( + SELECT 1 as val + UNION ALL + ( + WITH sub_cte AS ( + SELECT 2 as val + ) + SELECT + 2 as val + FROM recursive_cte + CROSS JOIN sub_cte + WHERE recursive_cte.val < 2 + ) +) +SELECT * FROM recursive_cte; +---- +1 +2 + +# Test issue: https://github.com/apache/arrow-datafusion/issues/9680 +# 'recursive_cte' should be on the left of the cross join, as this is the test purpose of the above query. +query TT +explain WITH RECURSIVE recursive_cte AS ( + SELECT 1 as val + UNION ALL + ( + WITH sub_cte AS ( + SELECT 2 as val + ) + SELECT + 2 as val + FROM recursive_cte + CROSS JOIN sub_cte + WHERE recursive_cte.val < 2 + ) +) +SELECT * FROM recursive_cte; +---- +logical_plan +Projection: recursive_cte.val +--SubqueryAlias: recursive_cte +----RecursiveQuery: is_distinct=false +------Projection: Int64(1) AS val +--------EmptyRelation +------Projection: Int64(2) AS val +--------CrossJoin: +----------Filter: recursive_cte.val < Int64(2) +------------TableScan: recursive_cte +----------SubqueryAlias: sub_cte +------------Projection: Int64(2) AS val +--------------EmptyRelation +physical_plan +RecursiveQueryExec: name=recursive_cte, is_distinct=false +--ProjectionExec: expr=[1 as val] +----PlaceholderRowExec +--ProjectionExec: expr=[2 as val] +----CrossJoinExec +------CoalescePartitionsExec +--------CoalesceBatchesExec: target_batch_size=8182 +----------FilterExec: val@0 < 2 +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------WorkTableExec: name=recursive_cte +------ProjectionExec: expr=[2 as val] +--------PlaceholderRowExec From 1d8a41bc8e08b56e90d6f8e6ef20e39a126987e4 Mon Sep 17 00:00:00 2001 From: "Reilly.tang" Date: Thu, 21 Mar 2024 07:57:05 +0800 Subject: [PATCH 31/35] Move `starts_with`, `to_hex`,` trim`, `upper` to datafusion-functions (and add string_expressions) (#9541) * [task #9539] Move starts_with, to_hex, trim, upper to datafusion-functions Signed-off-by: tangruilin * Export expr_fn, restore tests * fix comments --------- Signed-off-by: tangruilin Co-authored-by: Andrew Lamb --- datafusion/expr/src/built_in_function.rs | 57 +--- datafusion/expr/src/expr_fn.rs | 18 -- datafusion/functions/Cargo.toml | 3 + datafusion/functions/src/lib.rs | 9 +- datafusion/functions/src/string/mod.rs | 292 ++++++++++++++++++ .../functions/src/string/starts_with.rs | 89 ++++++ datafusion/functions/src/string/to_hex.rs | 155 ++++++++++ datafusion/functions/src/string/trim.rs | 78 +++++ datafusion/functions/src/string/upper.rs | 66 ++++ datafusion/physical-expr/src/functions.rs | 118 ------- .../physical-expr/src/string_expressions.rs | 77 +---- datafusion/proto/proto/datafusion.proto | 8 +- datafusion/proto/src/generated/pbjson.rs | 12 - datafusion/proto/src/generated/prost.rs | 16 +- .../proto/src/logical_plan/from_proto.rs | 22 +- datafusion/proto/src/logical_plan/to_proto.rs | 4 - datafusion/sql/src/expr/mod.rs | 2 +- 17 files changed, 720 insertions(+), 306 deletions(-) create mode 100644 datafusion/functions/src/string/mod.rs create mode 100644 datafusion/functions/src/string/starts_with.rs create mode 100644 datafusion/functions/src/string/to_hex.rs create mode 100644 datafusion/functions/src/string/trim.rs create mode 100644 datafusion/functions/src/string/upper.rs diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 79cd6a24ce39..fffe2cf4c9c9 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -147,20 +147,12 @@ pub enum BuiltinScalarFunction { Rtrim, /// split_part SplitPart, - /// starts_with - StartsWith, /// strpos Strpos, /// substr Substr, - /// to_hex - ToHex, /// translate Translate, - /// trim - Trim, - /// upper - Upper, /// uuid Uuid, /// overlay @@ -276,13 +268,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Rpad => Volatility::Immutable, BuiltinScalarFunction::Rtrim => Volatility::Immutable, BuiltinScalarFunction::SplitPart => Volatility::Immutable, - BuiltinScalarFunction::StartsWith => Volatility::Immutable, BuiltinScalarFunction::Strpos => Volatility::Immutable, BuiltinScalarFunction::Substr => Volatility::Immutable, - BuiltinScalarFunction::ToHex => Volatility::Immutable, BuiltinScalarFunction::Translate => Volatility::Immutable, - BuiltinScalarFunction::Trim => Volatility::Immutable, - BuiltinScalarFunction::Upper => Volatility::Immutable, BuiltinScalarFunction::OverLay => Volatility::Immutable, BuiltinScalarFunction::Levenshtein => Volatility::Immutable, BuiltinScalarFunction::SubstrIndex => Volatility::Immutable, @@ -365,7 +353,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::SplitPart => { utf8_to_str_type(&input_expr_types[0], "split_part") } - BuiltinScalarFunction::StartsWith => Ok(Boolean), BuiltinScalarFunction::EndsWith => Ok(Boolean), BuiltinScalarFunction::Strpos => { utf8_to_int_type(&input_expr_types[0], "strpos/instr/position") @@ -373,12 +360,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Substr => { utf8_to_str_type(&input_expr_types[0], "substr") } - BuiltinScalarFunction::ToHex => Ok(match input_expr_types[0] { - Int8 | Int16 | Int32 | Int64 => Utf8, - _ => { - return plan_err!("The to_hex function can only accept integers."); - } - }), BuiltinScalarFunction::SubstrIndex => { utf8_to_str_type(&input_expr_types[0], "substr_index") } @@ -388,10 +369,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Translate => { utf8_to_str_type(&input_expr_types[0], "translate") } - BuiltinScalarFunction::Trim => utf8_to_str_type(&input_expr_types[0], "trim"), - BuiltinScalarFunction::Upper => { - utf8_to_str_type(&input_expr_types[0], "upper") - } BuiltinScalarFunction::Factorial | BuiltinScalarFunction::Gcd @@ -476,18 +453,16 @@ impl BuiltinScalarFunction { | BuiltinScalarFunction::InitCap | BuiltinScalarFunction::Lower | BuiltinScalarFunction::OctetLength - | BuiltinScalarFunction::Reverse - | BuiltinScalarFunction::Upper => { + | BuiltinScalarFunction::Reverse => { Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility()) } BuiltinScalarFunction::Btrim | BuiltinScalarFunction::Ltrim - | BuiltinScalarFunction::Rtrim - | BuiltinScalarFunction::Trim => Signature::one_of( + | BuiltinScalarFunction::Rtrim => Signature::one_of( vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], self.volatility(), ), - BuiltinScalarFunction::Chr | BuiltinScalarFunction::ToHex => { + BuiltinScalarFunction::Chr => { Signature::uniform(1, vec![Int64], self.volatility()) } BuiltinScalarFunction::Lpad | BuiltinScalarFunction::Rpad => { @@ -519,17 +494,17 @@ impl BuiltinScalarFunction { self.volatility(), ), - BuiltinScalarFunction::EndsWith - | BuiltinScalarFunction::Strpos - | BuiltinScalarFunction::StartsWith => Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![Utf8, LargeUtf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![LargeUtf8, LargeUtf8]), - ], - self.volatility(), - ), + BuiltinScalarFunction::EndsWith | BuiltinScalarFunction::Strpos => { + Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8, LargeUtf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), + ], + self.volatility(), + ) + } BuiltinScalarFunction::Substr => Signature::one_of( vec![ @@ -749,13 +724,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Rpad => &["rpad"], BuiltinScalarFunction::Rtrim => &["rtrim"], BuiltinScalarFunction::SplitPart => &["split_part"], - BuiltinScalarFunction::StartsWith => &["starts_with"], BuiltinScalarFunction::Strpos => &["strpos", "instr", "position"], BuiltinScalarFunction::Substr => &["substr"], - BuiltinScalarFunction::ToHex => &["to_hex"], BuiltinScalarFunction::Translate => &["translate"], - BuiltinScalarFunction::Trim => &["trim"], - BuiltinScalarFunction::Upper => &["upper"], BuiltinScalarFunction::Uuid => &["uuid"], BuiltinScalarFunction::Levenshtein => &["levenshtein"], BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index b76164a1c83c..8667f631c507 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -575,12 +575,6 @@ scalar_expr!(Log10, log10, num, "base 10 logarithm of number"); scalar_expr!(Ln, ln, num, "natural logarithm (base e) of number"); scalar_expr!(Power, power, base exponent, "`base` raised to the power of `exponent`"); scalar_expr!(Atan2, atan2, y x, "inverse tangent of a division given in the argument"); -scalar_expr!( - ToHex, - to_hex, - num, - "returns the hexdecimal representation of an integer" -); scalar_expr!(Uuid, uuid, , "returns uuid v4 as a string value"); scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); @@ -630,19 +624,11 @@ scalar_expr!( "removes all characters, spaces by default, from the end of a string" ); scalar_expr!(SplitPart, split_part, string delimiter index, "splits a string based on a delimiter and picks out the desired field based on the index."); -scalar_expr!(StartsWith, starts_with, string prefix, "whether the `string` starts with the `prefix`"); scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`"); scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`"); scalar_expr!(Substr, substr, string position, "substring from the `position` to the end"); scalar_expr!(Substr, substring, string position length, "substring from the `position` with `length` characters"); scalar_expr!(Translate, translate, string from to, "replaces the characters in `from` with the counterpart in `to`"); -scalar_expr!( - Trim, - trim, - string, - "removes all characters, space by default from the string" -); -scalar_expr!(Upper, upper, string, "converts the string to upper case"); //use vec as parameter nary_scalar_expr!( Lpad, @@ -1117,15 +1103,11 @@ mod test { test_nary_scalar_expr!(Rpad, rpad, string, count, characters); test_scalar_expr!(Rtrim, rtrim, string); test_scalar_expr!(SplitPart, split_part, expr, delimiter, index); - test_scalar_expr!(StartsWith, starts_with, string, characters); test_scalar_expr!(EndsWith, ends_with, string, characters); test_scalar_expr!(Strpos, strpos, string, substring); test_scalar_expr!(Substr, substr, string, position); test_scalar_expr!(Substr, substring, string, position, count); - test_scalar_expr!(ToHex, to_hex, string); test_scalar_expr!(Translate, translate, string, from, to); - test_scalar_expr!(Trim, trim, string); - test_scalar_expr!(Upper, upper, string); test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len); test_nary_scalar_expr!(OverLay, overlay, string, characters, position); test_scalar_expr!(Levenshtein, levenshtein, string1, string2); diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 5a6da5345d7c..b12c99e84a90 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -29,6 +29,8 @@ authors = { workspace = true } rust-version = { workspace = true } [features] +# enable string functions +string_expressions = [] # enable core functions core_expressions = [] # enable datetime functions @@ -41,6 +43,7 @@ default = [ "math_expressions", "regex_expressions", "crypto_expressions", + "string_expressions", ] # enable encode/decode functions encoding_expressions = ["base64", "hex"] diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index 3a2eab8e5f05..f469b343e144 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -84,6 +84,10 @@ use log::debug; #[macro_use] pub mod macros; +#[cfg(feature = "string_expressions")] +pub mod string; +make_stub_package!(string, "string_expressions"); + /// Core datafusion expressions /// Enabled via feature flag `core_expressions` #[cfg(feature = "core_expressions")] @@ -134,6 +138,8 @@ pub mod expr_fn { pub use super::math::expr_fn::*; #[cfg(feature = "regex_expressions")] pub use super::regex::expr_fn::*; + #[cfg(feature = "string_expressions")] + pub use super::string::expr_fn::*; } /// Registers all enabled packages with a [`FunctionRegistry`] @@ -144,7 +150,8 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { .chain(encoding::functions()) .chain(math::functions()) .chain(regex::functions()) - .chain(crypto::functions()); + .chain(crypto::functions()) + .chain(string::functions()); all_functions.try_for_each(|udf| { let existing_udf = registry.register_udf(udf)?; diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs new file mode 100644 index 000000000000..08fcbb363bbc --- /dev/null +++ b/datafusion/functions/src/string/mod.rs @@ -0,0 +1,292 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{ + array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait}, + datatypes::DataType, +}; +use datafusion_common::{ + cast::as_generic_string_array, exec_err, plan_err, Result, ScalarValue, +}; +use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; +use datafusion_physical_expr::functions::Hint; +use std::{ + fmt::{Display, Formatter}, + sync::Arc, +}; + +/// Creates a function to identify the optimal return type of a string function given +/// the type of its first argument. +/// +/// If the input type is `LargeUtf8` or `LargeBinary` the return type is +/// `$largeUtf8Type`, +/// +/// If the input type is `Utf8` or `Binary` the return type is `$utf8Type`, +macro_rules! get_optimal_return_type { + ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { + fn $FUNC(arg_type: &DataType, name: &str) -> Result { + Ok(match arg_type { + // LargeBinary inputs are automatically coerced to Utf8 + DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, + // Binary inputs are automatically coerced to Utf8 + DataType::Utf8 | DataType::Binary => $utf8Type, + DataType::Null => DataType::Null, + DataType::Dictionary(_, value_type) => match **value_type { + DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, + DataType::Utf8 | DataType::Binary => $utf8Type, + DataType::Null => DataType::Null, + _ => { + return plan_err!( + "The {} function can only accept strings, but got {:?}.", + name.to_uppercase(), + **value_type + ); + } + }, + data_type => { + return plan_err!( + "The {} function can only accept strings, but got {:?}.", + name.to_uppercase(), + data_type + ); + } + }) + } + }; +} + +// `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size. +get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); + +/// applies a unary expression to `args[0]` that is expected to be downcastable to +/// a `GenericStringArray` and returns a `GenericStringArray` (which may have a different offset) +/// # Errors +/// This function errors when: +/// * the number of arguments is not 1 +/// * the first argument is not castable to a `GenericStringArray` +pub(crate) fn unary_string_function<'a, T, O, F, R>( + args: &[&'a dyn Array], + op: F, + name: &str, +) -> Result> +where + R: AsRef, + O: OffsetSizeTrait, + T: OffsetSizeTrait, + F: Fn(&'a str) -> R, +{ + if args.len() != 1 { + return exec_err!( + "{:?} args were supplied but {} takes exactly one argument", + args.len(), + name + ); + } + + let string_array = as_generic_string_array::(args[0])?; + + // first map is the iterator, second is for the `Option<_>` + Ok(string_array.iter().map(|string| string.map(&op)).collect()) +} + +fn handle<'a, F, R>(args: &'a [ColumnarValue], op: F, name: &str) -> Result +where + R: AsRef, + F: Fn(&'a str) -> R, +{ + match &args[0] { + ColumnarValue::Array(a) => match a.data_type() { + DataType::Utf8 => { + Ok(ColumnarValue::Array(Arc::new(unary_string_function::< + i32, + i32, + _, + _, + >( + &[a.as_ref()], op, name + )?))) + } + DataType::LargeUtf8 => { + Ok(ColumnarValue::Array(Arc::new(unary_string_function::< + i64, + i64, + _, + _, + >( + &[a.as_ref()], op, name + )?))) + } + other => exec_err!("Unsupported data type {other:?} for function {name}"), + }, + ColumnarValue::Scalar(scalar) => match scalar { + ScalarValue::Utf8(a) => { + let result = a.as_ref().map(|x| (op)(x).as_ref().to_string()); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) + } + ScalarValue::LargeUtf8(a) => { + let result = a.as_ref().map(|x| (op)(x).as_ref().to_string()); + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result))) + } + other => exec_err!("Unsupported data type {other:?} for function {name}"), + }, + } +} + +// TODO: mode allow[(dead_code)] after move ltrim and rtrim +enum TrimType { + #[allow(dead_code)] + Left, + #[allow(dead_code)] + Right, + Both, +} + +impl Display for TrimType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + TrimType::Left => write!(f, "ltrim"), + TrimType::Right => write!(f, "rtrim"), + TrimType::Both => write!(f, "btrim"), + } + } +} + +fn general_trim( + args: &[ArrayRef], + trim_type: TrimType, +) -> Result { + let func = match trim_type { + TrimType::Left => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_start_matches::<&[char]>(input, pattern.as_ref()) + }, + TrimType::Right => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_end_matches::<&[char]>(input, pattern.as_ref()) + }, + TrimType::Both => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_end_matches::<&[char]>( + str::trim_start_matches::<&[char]>(input, pattern.as_ref()), + pattern.as_ref(), + ) + }, + }; + + let string_array = as_generic_string_array::(&args[0])?; + + match args.len() { + 1 => { + let result = string_array + .iter() + .map(|string| string.map(|string: &str| func(string, " "))) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 2 => { + let characters_array = as_generic_string_array::(&args[1])?; + + let result = string_array + .iter() + .zip(characters_array.iter()) + .map(|(string, characters)| match (string, characters) { + (Some(string), Some(characters)) => Some(func(string, characters)), + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + other => { + exec_err!( + "{trim_type} was called with {other} arguments. It requires at least 1 and at most 2." + ) + } + } +} + +pub(super) fn make_scalar_function( + inner: F, + hints: Vec, +) -> ScalarFunctionImplementation +where + F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, +{ + Arc::new(move |args: &[ColumnarValue]| { + // first, identify if any of the arguments is an Array. If yes, store its `len`, + // as any scalar will need to be converted to an array of len `len`. + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let is_scalar = len.is_none(); + + let inferred_length = len.unwrap_or(1); + let args = args + .iter() + .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad))) + .map(|(arg, hint)| { + // Decide on the length to expand this scalar to depending + // on the given hints. + let expansion_len = match hint { + Hint::AcceptsSingular => 1, + Hint::Pad => inferred_length, + }; + arg.clone().into_array(expansion_len) + }) + .collect::>>()?; + + let result = (inner)(&args); + if is_scalar { + // If all inputs are scalar, keeps output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) + } + }) +} + +mod starts_with; +mod to_hex; +mod trim; +mod upper; +// create UDFs +make_udf_function!(starts_with::StartsWithFunc, STARTS_WITH, starts_with); +make_udf_function!(to_hex::ToHexFunc, TO_HEX, to_hex); +make_udf_function!(trim::TrimFunc, TRIM, trim); +make_udf_function!(upper::UpperFunc, UPPER, upper); + +export_functions!( + ( + starts_with, + arg1 arg2, + "Returns true if string starts with prefix."), + ( + to_hex, + arg1, + "Converts an integer to a hexadecimal string."), + (trim, + arg1, + "removes all characters, space by default from the string"), + (upper, + arg1, + "Converts a string to uppercase.")); diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs new file mode 100644 index 000000000000..1fce399d1e70 --- /dev/null +++ b/datafusion/functions/src/string/starts_with.rs @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion_common::{cast::as_generic_string_array, internal_err, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +use crate::string::make_scalar_function; + +/// Returns true if string starts with prefix. +/// starts_with('alphabet', 'alph') = 't' +pub fn starts_with(args: &[ArrayRef]) -> Result { + let left = as_generic_string_array::(&args[0])?; + let right = as_generic_string_array::(&args[1])?; + + let result = arrow::compute::kernels::comparison::starts_with(left, right)?; + + Ok(Arc::new(result) as ArrayRef) +} + +#[derive(Debug)] +pub(super) struct StartsWithFunc { + signature: Signature, +} +impl StartsWithFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8, LargeUtf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StartsWithFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "starts_with" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + use DataType::*; + + Ok(Boolean) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(starts_with::, vec![])(args), + DataType::LargeUtf8 => { + return make_scalar_function(starts_with::, vec![])(args); + } + _ => internal_err!("Unsupported data type"), + } + } +} diff --git a/datafusion/functions/src/string/to_hex.rs b/datafusion/functions/src/string/to_hex.rs new file mode 100644 index 000000000000..4dfc84887da2 --- /dev/null +++ b/datafusion/functions/src/string/to_hex.rs @@ -0,0 +1,155 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::{ + ArrowNativeType, ArrowPrimitiveType, DataType, Int32Type, Int64Type, +}; +use datafusion_common::cast::as_primitive_array; +use datafusion_common::Result; +use datafusion_common::{exec_err, plan_err}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +use super::make_scalar_function; + +/// Converts the number to its equivalent hexadecimal representation. +/// to_hex(2147483647) = '7fffffff' +pub fn to_hex(args: &[ArrayRef]) -> Result +where + T::Native: OffsetSizeTrait, +{ + let integer_array = as_primitive_array::(&args[0])?; + + let result = integer_array + .iter() + .map(|integer| { + if let Some(value) = integer { + if let Some(value_usize) = value.to_usize() { + Ok(Some(format!("{value_usize:x}"))) + } else if let Some(value_isize) = value.to_isize() { + Ok(Some(format!("{value_isize:x}"))) + } else { + exec_err!("Unsupported data type {integer:?} for function to_hex") + } + } else { + Ok(None) + } + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) +} + +#[derive(Debug)] +pub(super) struct ToHexFunc { + signature: Signature, +} +impl ToHexFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform(1, vec![Int64], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for ToHexFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "to_hex" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + use DataType::*; + + Ok(match arg_types[0] { + Int8 | Int16 | Int32 | Int64 => Utf8, + _ => { + return plan_err!("The to_hex function can only accept integers."); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Int32 => make_scalar_function(to_hex::, vec![])(args), + DataType::Int64 => make_scalar_function(to_hex::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function to_hex"), + } + } +} + +#[cfg(test)] +mod tests { + use arrow::{ + array::{Int32Array, StringArray}, + datatypes::Int32Type, + }; + + use datafusion_common::cast::as_string_array; + + use super::*; + + #[test] + // Test to_hex function for zero + fn to_hex_zero() -> Result<()> { + let array = vec![0].into_iter().collect::(); + let array_ref = Arc::new(array); + let hex_value_arc = to_hex::(&[array_ref])?; + let hex_value = as_string_array(&hex_value_arc)?; + let expected = StringArray::from(vec![Some("0")]); + assert_eq!(&expected, hex_value); + + Ok(()) + } + + #[test] + // Test to_hex function for positive number + fn to_hex_positive_number() -> Result<()> { + let array = vec![100].into_iter().collect::(); + let array_ref = Arc::new(array); + let hex_value_arc = to_hex::(&[array_ref])?; + let hex_value = as_string_array(&hex_value_arc)?; + let expected = StringArray::from(vec![Some("64")]); + assert_eq!(&expected, hex_value); + + Ok(()) + } + + #[test] + // Test to_hex function for negative number + fn to_hex_negative_number() -> Result<()> { + let array = vec![-1].into_iter().collect::(); + let array_ref = Arc::new(array); + let hex_value_arc = to_hex::(&[array_ref])?; + let hex_value = as_string_array(&hex_value_arc)?; + let expected = StringArray::from(vec![Some("ffffffffffffffff")]); + assert_eq!(&expected, hex_value); + + Ok(()) + } +} diff --git a/datafusion/functions/src/string/trim.rs b/datafusion/functions/src/string/trim.rs new file mode 100644 index 000000000000..e04a171722e3 --- /dev/null +++ b/datafusion/functions/src/string/trim.rs @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion_common::exec_err; +use datafusion_common::Result; +use datafusion_expr::ColumnarValue; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; + +use crate::string::{make_scalar_function, utf8_to_str_type}; + +use super::{general_trim, TrimType}; + +/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed. +/// btrim('xyxtrimyyx', 'xyz') = 'trim' +pub fn btrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Both) +} + +#[derive(Debug)] +pub(super) struct TrimFunc { + signature: Signature, +} + +impl TrimFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for TrimFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "trim" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "trim") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(btrim::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(btrim::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function trim"), + } + } +} diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs new file mode 100644 index 000000000000..ed41487699aa --- /dev/null +++ b/datafusion/functions/src/string/upper.rs @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::DataType; +use datafusion_common::Result; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; + +use crate::string::utf8_to_str_type; + +use super::handle; + +#[derive(Debug)] +pub(super) struct UpperFunc { + signature: Signature, +} + +impl UpperFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Utf8, LargeUtf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for UpperFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "upper" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "upper") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + handle(args, |string| string.to_uppercase(), "upper") + } +} diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index e76e7f56dc95..f2c93c3ec1dd 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -447,17 +447,6 @@ pub fn create_physical_fun( exec_err!("Unsupported data type {other:?} for function split_part") } }), - BuiltinScalarFunction::StartsWith => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::starts_with::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::starts_with::)(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function starts_with") - } - }), BuiltinScalarFunction::EndsWith => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function_inner(string_expressions::ends_with::)(args) @@ -497,15 +486,6 @@ pub fn create_physical_fun( } other => exec_err!("Unsupported data type {other:?} for function substr"), }), - BuiltinScalarFunction::ToHex => Arc::new(|args| match args[0].data_type() { - DataType::Int32 => { - make_scalar_function_inner(string_expressions::to_hex::)(args) - } - DataType::Int64 => { - make_scalar_function_inner(string_expressions::to_hex::)(args) - } - other => exec_err!("Unsupported data type {other:?} for function to_hex"), - }), BuiltinScalarFunction::Translate => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!( @@ -527,16 +507,6 @@ pub fn create_physical_fun( exec_err!("Unsupported data type {other:?} for function translate") } }), - BuiltinScalarFunction::Trim => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::btrim::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::btrim::)(args) - } - other => exec_err!("Unsupported data type {other:?} for function trim"), - }), - BuiltinScalarFunction::Upper => Arc::new(string_expressions::upper), BuiltinScalarFunction::Uuid => Arc::new(string_expressions::uuid), BuiltinScalarFunction::OverLay => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { @@ -1797,38 +1767,6 @@ mod tests { Utf8, StringArray ); - test_function!( - StartsWith, - &[lit("alphabet"), lit("alph"),], - Ok(Some(true)), - bool, - Boolean, - BooleanArray - ); - test_function!( - StartsWith, - &[lit("alphabet"), lit("blph"),], - Ok(Some(false)), - bool, - Boolean, - BooleanArray - ); - test_function!( - StartsWith, - &[lit(ScalarValue::Utf8(None)), lit("alph"),], - Ok(None), - bool, - Boolean, - BooleanArray - ); - test_function!( - StartsWith, - &[lit("alphabet"), lit(ScalarValue::Utf8(None)),], - Ok(None), - bool, - Boolean, - BooleanArray - ); test_function!( EndsWith, &[lit("alphabet"), lit("alph"),], @@ -2149,62 +2087,6 @@ mod tests { Utf8, StringArray ); - test_function!( - Trim, - &[lit(" trim ")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Trim, - &[lit("trim ")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Trim, - &[lit(" trim")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Trim, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - &str, - Utf8, - StringArray - ); - test_function!( - Upper, - &[lit("upper")], - Ok(Some("UPPER")), - &str, - Utf8, - StringArray - ); - test_function!( - Upper, - &[lit("UPPER")], - Ok(Some("UPPER")), - &str, - Utf8, - StringArray - ); - test_function!( - Upper, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - &str, - Utf8, - StringArray - ); Ok(()) } diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index ace7ef2888a3..86c0092a220d 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -32,16 +32,14 @@ use arrow::{ Array, ArrayRef, GenericStringArray, Int32Array, Int64Array, OffsetSizeTrait, StringArray, }, - datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType}, + datatypes::DataType, }; use uuid::Uuid; use datafusion_common::utils::datafusion_strsim; use datafusion_common::Result; use datafusion_common::{ - cast::{ - as_generic_string_array, as_int64_array, as_primitive_array, as_string_array, - }, + cast::{as_generic_string_array, as_int64_array, as_string_array}, exec_err, ScalarValue, }; use datafusion_expr::ColumnarValue; @@ -526,34 +524,6 @@ pub fn ends_with(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } -/// Converts the number to its equivalent hexadecimal representation. -/// to_hex(2147483647) = '7fffffff' -pub fn to_hex(args: &[ArrayRef]) -> Result -where - T::Native: OffsetSizeTrait, -{ - let integer_array = as_primitive_array::(&args[0])?; - - let result = integer_array - .iter() - .map(|integer| { - if let Some(value) = integer { - if let Some(value_usize) = value.to_usize() { - Ok(Some(format!("{value_usize:x}"))) - } else if let Some(value_isize) = value.to_isize() { - Ok(Some(format!("{value_isize:x}"))) - } else { - exec_err!("Unsupported data type {integer:?} for function to_hex") - } - } else { - Ok(None) - } - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) -} - /// Converts the string to all upper case. /// upper('tom') = 'TOM' pub fn upper(args: &[ColumnarValue]) -> Result { @@ -709,54 +679,13 @@ pub fn levenshtein(args: &[ArrayRef]) -> Result { #[cfg(test)] mod tests { - use arrow::{array::Int32Array, datatypes::Int32Type}; + use arrow::array::Int32Array; use arrow_array::Int64Array; use datafusion_common::cast::as_int32_array; - use crate::string_expressions; - use super::*; - #[test] - // Test to_hex function for zero - fn to_hex_zero() -> Result<()> { - let array = vec![0].into_iter().collect::(); - let array_ref = Arc::new(array); - let hex_value_arc = string_expressions::to_hex::(&[array_ref])?; - let hex_value = as_string_array(&hex_value_arc)?; - let expected = StringArray::from(vec![Some("0")]); - assert_eq!(&expected, hex_value); - - Ok(()) - } - - #[test] - // Test to_hex function for positive number - fn to_hex_positive_number() -> Result<()> { - let array = vec![100].into_iter().collect::(); - let array_ref = Arc::new(array); - let hex_value_arc = string_expressions::to_hex::(&[array_ref])?; - let hex_value = as_string_array(&hex_value_arc)?; - let expected = StringArray::from(vec![Some("64")]); - assert_eq!(&expected, hex_value); - - Ok(()) - } - - #[test] - // Test to_hex function for negative number - fn to_hex_negative_number() -> Result<()> { - let array = vec![-1].into_iter().collect::(); - let array_ref = Arc::new(array); - let hex_value_arc = string_expressions::to_hex::(&[array_ref])?; - let hex_value = as_string_array(&hex_value_arc)?; - let expected = StringArray::from(vec![Some("ffffffffffffffff")]); - assert_eq!(&expected, hex_value); - - Ok(()) - } - #[test] fn to_overlay() -> Result<()> { let string = diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 10f79a2b8cc8..c009682d5a4d 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -592,18 +592,18 @@ enum ScalarFunction { // 48 was SHA384 // 49 was SHA512 SplitPart = 50; - StartsWith = 51; + // StartsWith = 51; Strpos = 52; Substr = 53; - ToHex = 54; + // ToHex = 54; // 55 was ToTimestamp // 56 was ToTimestampMillis // 57 was ToTimestampMicros // 58 was ToTimestampSeconds // 59 was Now Translate = 60; - Trim = 61; - Upper = 62; + // Trim = 61; + // Upper = 62; Coalesce = 63; Power = 64; // 65 was StructFun diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 7757a64ef359..58683dba6dff 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22949,13 +22949,9 @@ impl serde::Serialize for ScalarFunction { Self::Rpad => "Rpad", Self::Rtrim => "Rtrim", Self::SplitPart => "SplitPart", - Self::StartsWith => "StartsWith", Self::Strpos => "Strpos", Self::Substr => "Substr", - Self::ToHex => "ToHex", Self::Translate => "Translate", - Self::Trim => "Trim", - Self::Upper => "Upper", Self::Coalesce => "Coalesce", Self::Power => "Power", Self::Atan2 => "Atan2", @@ -23027,13 +23023,9 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Rpad", "Rtrim", "SplitPart", - "StartsWith", "Strpos", "Substr", - "ToHex", "Translate", - "Trim", - "Upper", "Coalesce", "Power", "Atan2", @@ -23134,13 +23126,9 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Rpad" => Ok(ScalarFunction::Rpad), "Rtrim" => Ok(ScalarFunction::Rtrim), "SplitPart" => Ok(ScalarFunction::SplitPart), - "StartsWith" => Ok(ScalarFunction::StartsWith), "Strpos" => Ok(ScalarFunction::Strpos), "Substr" => Ok(ScalarFunction::Substr), - "ToHex" => Ok(ScalarFunction::ToHex), "Translate" => Ok(ScalarFunction::Translate), - "Trim" => Ok(ScalarFunction::Trim), - "Upper" => Ok(ScalarFunction::Upper), "Coalesce" => Ok(ScalarFunction::Coalesce), "Power" => Ok(ScalarFunction::Power), "Atan2" => Ok(ScalarFunction::Atan2), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index ab0ddb14ebfc..8eabb3b18603 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2891,18 +2891,18 @@ pub enum ScalarFunction { /// 48 was SHA384 /// 49 was SHA512 SplitPart = 50, - StartsWith = 51, + /// StartsWith = 51; Strpos = 52, Substr = 53, - ToHex = 54, + /// ToHex = 54; /// 55 was ToTimestamp /// 56 was ToTimestampMillis /// 57 was ToTimestampMicros /// 58 was ToTimestampSeconds /// 59 was Now Translate = 60, - Trim = 61, - Upper = 62, + /// Trim = 61; + /// Upper = 62; Coalesce = 63, Power = 64, /// 65 was StructFun @@ -3022,13 +3022,9 @@ impl ScalarFunction { ScalarFunction::Rpad => "Rpad", ScalarFunction::Rtrim => "Rtrim", ScalarFunction::SplitPart => "SplitPart", - ScalarFunction::StartsWith => "StartsWith", ScalarFunction::Strpos => "Strpos", ScalarFunction::Substr => "Substr", - ScalarFunction::ToHex => "ToHex", ScalarFunction::Translate => "Translate", - ScalarFunction::Trim => "Trim", - ScalarFunction::Upper => "Upper", ScalarFunction::Coalesce => "Coalesce", ScalarFunction::Power => "Power", ScalarFunction::Atan2 => "Atan2", @@ -3094,13 +3090,9 @@ impl ScalarFunction { "Rpad" => Some(Self::Rpad), "Rtrim" => Some(Self::Rtrim), "SplitPart" => Some(Self::SplitPart), - "StartsWith" => Some(Self::StartsWith), "Strpos" => Some(Self::Strpos), "Substr" => Some(Self::Substr), - "ToHex" => Some(Self::ToHex), "Translate" => Some(Self::Translate), - "Trim" => Some(Self::Trim), - "Upper" => Some(Self::Upper), "Coalesce" => Some(Self::Coalesce), "Power" => Some(Self::Power), "Atan2" => Some(Self::Atan2), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 8581156e2bb8..64ceb37d2961 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -57,10 +57,9 @@ use datafusion_expr::{ logical_plan::{PlanType, StringifiedPlan}, lower, lpad, ltrim, nanvl, octet_length, overlay, pi, power, radians, random, repeat, replace, reverse, right, round, rpad, rtrim, signum, sin, sinh, split_part, sqrt, - starts_with, strpos, substr, substr_index, substring, to_hex, translate, trim, trunc, - upper, uuid, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, - BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, - GroupingSet, + strpos, substr, substr_index, substring, translate, trunc, uuid, AggregateFunction, + Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, + GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -462,8 +461,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::OctetLength => Self::OctetLength, ScalarFunction::Concat => Self::Concat, ScalarFunction::Lower => Self::Lower, - ScalarFunction::Upper => Self::Upper, - ScalarFunction::Trim => Self::Trim, ScalarFunction::Ltrim => Self::Ltrim, ScalarFunction::Rtrim => Self::Rtrim, ScalarFunction::Log2 => Self::Log2, @@ -485,10 +482,8 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Right => Self::Right, ScalarFunction::Rpad => Self::Rpad, ScalarFunction::SplitPart => Self::SplitPart, - ScalarFunction::StartsWith => Self::StartsWith, ScalarFunction::Strpos => Self::Strpos, ScalarFunction::Substr => Self::Substr, - ScalarFunction::ToHex => Self::ToHex, ScalarFunction::Uuid => Self::Uuid, ScalarFunction::Translate => Self::Translate, ScalarFunction::Coalesce => Self::Coalesce, @@ -1444,10 +1439,6 @@ pub fn parse_expr( ScalarFunction::Lower => { Ok(lower(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Upper => { - Ok(upper(parse_expr(&args[0], registry, codec)?)) - } - ScalarFunction::Trim => Ok(trim(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Ltrim => { Ok(ltrim(parse_expr(&args[0], registry, codec)?)) } @@ -1532,10 +1523,6 @@ pub fn parse_expr( parse_expr(&args[1], registry, codec)?, parse_expr(&args[2], registry, codec)?, )), - ScalarFunction::StartsWith => Ok(starts_with( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), ScalarFunction::EndsWith => Ok(ends_with( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, @@ -1563,9 +1550,6 @@ pub fn parse_expr( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::ToHex => { - Ok(to_hex(parse_expr(&args[0], registry, codec)?)) - } ScalarFunction::Translate => Ok(translate( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 05a29ff6d42b..89bd93550a04 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1481,8 +1481,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::OctetLength => Self::OctetLength, BuiltinScalarFunction::Concat => Self::Concat, BuiltinScalarFunction::Lower => Self::Lower, - BuiltinScalarFunction::Upper => Self::Upper, - BuiltinScalarFunction::Trim => Self::Trim, BuiltinScalarFunction::Ltrim => Self::Ltrim, BuiltinScalarFunction::Rtrim => Self::Rtrim, BuiltinScalarFunction::Log2 => Self::Log2, @@ -1505,10 +1503,8 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Right => Self::Right, BuiltinScalarFunction::Rpad => Self::Rpad, BuiltinScalarFunction::SplitPart => Self::SplitPart, - BuiltinScalarFunction::StartsWith => Self::StartsWith, BuiltinScalarFunction::Strpos => Self::Strpos, BuiltinScalarFunction::Substr => Self::Substr, - BuiltinScalarFunction::ToHex => Self::ToHex, BuiltinScalarFunction::Translate => Self::Translate, BuiltinScalarFunction::Coalesce => Self::Coalesce, BuiltinScalarFunction::Pi => Self::Pi, diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 5e9c0623a265..c34b42193cec 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -747,7 +747,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Some(TrimWhereField::Leading) => BuiltinScalarFunction::Ltrim, Some(TrimWhereField::Trailing) => BuiltinScalarFunction::Rtrim, Some(TrimWhereField::Both) => BuiltinScalarFunction::Btrim, - None => BuiltinScalarFunction::Trim, + None => BuiltinScalarFunction::Btrim, }; let arg = self.sql_expr_to_logical_expr(expr, schema, planner_context)?; From dc373a3550610ce041fd73a1eabe08b096d6ed27 Mon Sep 17 00:00:00 2001 From: Jeffrey Vo Date: Fri, 22 Mar 2024 01:13:44 +1100 Subject: [PATCH 32/35] Support for `extract(x from time)` / `date_part` from time types (#8693) * Initial support for `extract(x from time)` * Update function docs * Add extract tests --- datafusion/common/src/cast.rs | 37 ++- .../functions/src/datetime/date_part.rs | 31 +- datafusion/sqllogictest/test_files/expr.slt | 287 ++++++++++++++++++ .../source/user-guide/sql/scalar_functions.md | 27 +- 4 files changed, 345 insertions(+), 37 deletions(-) diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index 088f03e002ed..0dc0532bbb6f 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -24,17 +24,18 @@ use crate::{downcast_value, DataFusionError, Result}; use arrow::{ array::{ Array, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, - DictionaryArray, FixedSizeBinaryArray, FixedSizeListArray, Float32Array, - Float64Array, GenericBinaryArray, GenericListArray, GenericStringArray, - Int32Array, Int64Array, IntervalDayTimeArray, IntervalMonthDayNanoArray, - IntervalYearMonthArray, LargeListArray, ListArray, MapArray, NullArray, - OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, UInt32Array, UInt64Array, UInt8Array, UnionArray, + Decimal256Array, DictionaryArray, FixedSizeBinaryArray, FixedSizeListArray, + Float32Array, Float64Array, GenericBinaryArray, GenericListArray, + GenericStringArray, Int32Array, Int64Array, IntervalDayTimeArray, + IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeListArray, ListArray, + MapArray, NullArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, + Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, + Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt32Array, UInt64Array, + UInt8Array, UnionArray, }, datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType}, }; -use arrow_array::Decimal256Array; // Downcast ArrayRef to Date32Array pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array> { @@ -154,6 +155,26 @@ pub fn as_union_array(array: &dyn Array) -> Result<&UnionArray> { Ok(downcast_value!(array, UnionArray)) } +// Downcast ArrayRef to Time32SecondArray +pub fn as_time32_second_array(array: &dyn Array) -> Result<&Time32SecondArray> { + Ok(downcast_value!(array, Time32SecondArray)) +} + +// Downcast ArrayRef to Time32MillisecondArray +pub fn as_time32_millisecond_array(array: &dyn Array) -> Result<&Time32MillisecondArray> { + Ok(downcast_value!(array, Time32MillisecondArray)) +} + +// Downcast ArrayRef to Time64MicrosecondArray +pub fn as_time64_microsecond_array(array: &dyn Array) -> Result<&Time64MicrosecondArray> { + Ok(downcast_value!(array, Time64MicrosecondArray)) +} + +// Downcast ArrayRef to Time64NanosecondArray +pub fn as_time64_nanosecond_array(array: &dyn Array) -> Result<&Time64NanosecondArray> { + Ok(downcast_value!(array, Time64NanosecondArray)) +} + // Downcast ArrayRef to TimestampNanosecondArray pub fn as_timestamp_nanosecond_array( array: &dyn Array, diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index 5d2719bf0365..b41f7e13cff2 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -20,14 +20,17 @@ use std::sync::Arc; use arrow::array::{Array, ArrayRef, Float64Array}; use arrow::compute::{binary, cast, date_part, DatePart}; -use arrow::datatypes::DataType::{Date32, Date64, Float64, Timestamp, Utf8}; +use arrow::datatypes::DataType::{ + Date32, Date64, Float64, Time32, Time64, Timestamp, Utf8, +}; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::cast::{ - as_date32_array, as_date64_array, as_int32_array, as_timestamp_microsecond_array, - as_timestamp_millisecond_array, as_timestamp_nanosecond_array, - as_timestamp_second_array, + as_date32_array, as_date64_array, as_int32_array, as_time32_millisecond_array, + as_time32_second_array, as_time64_microsecond_array, as_time64_nanosecond_array, + as_timestamp_microsecond_array, as_timestamp_millisecond_array, + as_timestamp_nanosecond_array, as_timestamp_second_array, }; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::TypeSignature::Exact; @@ -68,6 +71,10 @@ impl DatePartFunc { ]), Exact(vec![Utf8, Date64]), Exact(vec![Utf8, Date32]), + Exact(vec![Utf8, Time32(Second)]), + Exact(vec![Utf8, Time32(Millisecond)]), + Exact(vec![Utf8, Time64(Microsecond)]), + Exact(vec![Utf8, Time64(Nanosecond)]), ], Volatility::Immutable, ), @@ -149,12 +156,9 @@ fn date_part_f64(array: &dyn Array, part: DatePart) -> Result { Ok(cast(date_part(array, part)?.as_ref(), &Float64)?) } -/// invoke [`date_part`] on an `array` (e.g. Timestamp) and convert the +/// Invoke [`date_part`] on an `array` (e.g. Timestamp) and convert the /// result to a total number of seconds, milliseconds, microseconds or /// nanoseconds -/// -/// # Panics -/// If `array` is not a temporal type such as Timestamp or Date32 fn seconds(array: &dyn Array, unit: TimeUnit) -> Result { let sf = match unit { Second => 1_f64, @@ -163,6 +167,7 @@ fn seconds(array: &dyn Array, unit: TimeUnit) -> Result { Nanosecond => 1_000_000_000_f64, }; let secs = date_part(array, DatePart::Second)?; + // This assumes array is primitive and not a dictionary let secs = as_int32_array(secs.as_ref())?; let subsecs = date_part(array, DatePart::Nanosecond)?; let subsecs = as_int32_array(subsecs.as_ref())?; @@ -189,6 +194,16 @@ fn epoch(array: &dyn Array) -> Result { } Date32 => as_date32_array(array)?.unary(|x| x as f64 * SECONDS_IN_A_DAY), Date64 => as_date64_array(array)?.unary(|x| x as f64 / 1_000_f64), + Time32(Second) => as_time32_second_array(array)?.unary(|x| x as f64), + Time32(Millisecond) => { + as_time32_millisecond_array(array)?.unary(|x| x as f64 / 1_000_f64) + } + Time64(Microsecond) => { + as_time64_microsecond_array(array)?.unary(|x| x as f64 / 1_000_000_f64) + } + Time64(Nanosecond) => { + as_time64_nanosecond_array(array)?.unary(|x| x as f64 / 1_000_000_000_f64) + } d => return exec_err!("Can not convert {d:?} to epoch"), }; Ok(Arc::new(f)) diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 73fb5eec97d5..d6343f9a3fe8 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -939,6 +939,293 @@ SELECT date_part('nanosecond', '2020-09-08T12:00:12.12345678+00:00') ---- 12123456780 +# test_date_part_time + +## time32 seconds +query R +SELECT date_part('hour', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +23 + +query R +SELECT extract(hour from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +23 + +query R +SELECT date_part('minute', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +32 + +query R +SELECT extract(minute from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +32 + +query R +SELECT date_part('second', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50 + +query R +SELECT extract(second from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50 + +query R +SELECT date_part('millisecond', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000 + +query R +SELECT extract(millisecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000 + +query R +SELECT date_part('microsecond', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000000 + +query R +SELECT extract(microsecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000000 + +query R +SELECT date_part('nanosecond', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000000000 + +query R +SELECT extract(nanosecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000000000 + +query R +SELECT date_part('epoch', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +84770 + +query R +SELECT extract(epoch from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +84770 + +## time32 milliseconds +query R +SELECT date_part('hour', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +23 + +query R +SELECT extract(hour from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +23 + +query R +SELECT date_part('minute', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +32 + +query R +SELECT extract(minute from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +32 + +query R +SELECT date_part('second', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50.123 + +query R +SELECT extract(second from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50.123 + +query R +SELECT date_part('millisecond', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123 + +query R +SELECT extract(millisecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123 + +query R +SELECT date_part('microsecond', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123000 + +query R +SELECT extract(microsecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123000 + +query R +SELECT date_part('nanosecond', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123000000 + +query R +SELECT extract(nanosecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123000000 + +query R +SELECT date_part('epoch', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +84770.123 + +query R +SELECT extract(epoch from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +84770.123 + +## time64 microseconds +query R +SELECT date_part('hour', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +23 + +query R +SELECT extract(hour from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +23 + +query R +SELECT date_part('minute', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +32 + +query R +SELECT extract(minute from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +32 + +query R +SELECT date_part('second', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50.123456 + +query R +SELECT extract(second from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50.123456 + +query R +SELECT date_part('millisecond', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123.456 + +query R +SELECT extract(millisecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123.456 + +query R +SELECT date_part('microsecond', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123456 + +query R +SELECT extract(microsecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123456 + +query R +SELECT date_part('nanosecond', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123456000 + +query R +SELECT extract(nanosecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123456000 + +query R +SELECT date_part('epoch', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +84770.123456 + +query R +SELECT extract(epoch from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +84770.123456 + +## time64 nanoseconds +query R +SELECT date_part('hour', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +23 + +query R +SELECT extract(hour from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +23 + +query R +SELECT date_part('minute', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +32 + +query R +SELECT extract(minute from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +32 + +query R +SELECT date_part('second', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50.123456789 + +query R +SELECT extract(second from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50.123456789 + +query R +SELECT date_part('millisecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123.456789 + +query R +SELECT extract(millisecond from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123.456789 + +# just some floating point stuff happening in the result here +query R +SELECT date_part('microsecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123456.789000005 + +query R +SELECT extract(microsecond from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123456.789000005 + +query R +SELECT date_part('nanosecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123456789 + +query R +SELECT extract(nanosecond from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123456789 + +query R +SELECT date_part('epoch', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +84770.123456789 + +query R +SELECT extract(epoch from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +84770.123456789 + # test_extract_epoch query R diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index b63fa9950ae0..d4570dbc35f2 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1624,34 +1624,19 @@ _Alias of [date_part](#date_part)._ ### `extract` Returns a sub-field from a time value as an integer. -Similar to `date_part`, but with different arguments. ``` extract(field FROM source) ``` -#### Arguments - -- **field**: Part or field of the date to return. - The following date fields are supported: +Equivalent to calling `date_part('field', source)`. For example, these are equivalent: - - year - - quarter _(emits value in inclusive range [1, 4] based on which quartile of the year the date is in)_ - - month - - week _(week of the year)_ - - day _(day of the month)_ - - hour - - minute - - second - - millisecond - - microsecond - - nanosecond - - dow _(day of the week)_ - - doy _(day of the year)_ - - epoch _(seconds since Unix epoch)_ +```sql +extract(day FROM '2024-04-13'::date) +date_part('day', '2024-04-13'::date) +``` -- **source**: Source time expression to operate on. - Can be a constant, column, or function. +See [date_part](#date_part). ### `make_date` From edaf235828a90042eaf918ec4b3ee5ab2716f060 Mon Sep 17 00:00:00 2001 From: comphead Date: Thu, 21 Mar 2024 08:12:18 -0700 Subject: [PATCH 33/35] doc: Updated known users list and usage dependency description (#9718) * minor: update known users and usage description --- docs/source/user-guide/example-usage.md | 10 +++++----- docs/source/user-guide/introduction.md | 1 + 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/source/user-guide/example-usage.md b/docs/source/user-guide/example-usage.md index c5eefbdaf156..31b599ac3308 100644 --- a/docs/source/user-guide/example-usage.md +++ b/docs/source/user-guide/example-usage.md @@ -23,20 +23,20 @@ In this example some simple processing is performed on the [`example.csv`](https Even [`more code examples`](https://github.com/apache/arrow-datafusion/tree/main/datafusion-examples) attached to the project. -## Add DataFusion as a dependency +## Add published DataFusion dependency Find latest available Datafusion version on [DataFusion's crates.io] page. Add the dependency to your `Cargo.toml` file: ```toml -datafusion = "31" +datafusion = "latest_version" tokio = "1.0" ``` -## Add DataFusion latest codebase as a dependency +## Add latest non published DataFusion dependency -Cargo supports adding dependency directly from Github which allows testing out latest DataFusion codebase without waiting the code to be released to crates.io -according to the [DataFusion release schedule](https://github.com/apache/arrow-datafusion/blob/main/dev/release/README.md#release-process) +DataFusion changes are published to `crates.io` according to [release schedule](https://github.com/apache/arrow-datafusion/blob/main/dev/release/README.md#release-process) +In case if it is required to test out DataFusion changes which are merged but yet to be published, Cargo supports adding dependency directly to Github branch ```toml datafusion = { git = "https://github.com/apache/arrow-datafusion", branch = "main"} diff --git a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md index ae2684699726..0e9d731c6e21 100644 --- a/docs/source/user-guide/introduction.md +++ b/docs/source/user-guide/introduction.md @@ -96,6 +96,7 @@ Here are some active projects using DataFusion: - [Arroyo](https://github.com/ArroyoSystems/arroyo) Distributed stream processing engine in Rust - [Ballista](https://github.com/apache/arrow-ballista) Distributed SQL Query Engine +- [Comet](https://github.com/apache/arrow-datafusion-comet) Apache Spark native query execution plugin - [CnosDB](https://github.com/cnosdb/cnosdb) Open Source Distributed Time Series Database - [Cube Store](https://github.com/cube-js/cube.js/tree/master/rust) - [Dask SQL](https://github.com/dask-contrib/dask-sql) Distributed SQL query engine in Python From c5c9d3f57f361c6c01d0cb01c416f6a7e9dfd906 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 21 Mar 2024 11:18:14 -0400 Subject: [PATCH 34/35] Minor: improve documentation for `CommonSubexprEliminate` (#9700) --- .../optimizer/src/common_subexpr_eliminate.rs | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 7b8eccad5133..e73885c6aaef 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -53,10 +53,32 @@ type ExprSet = HashMap; /// here is not such a good choose. type Identifier = String; -/// Perform Common Sub-expression Elimination optimization. +/// Performs Common Sub-expression Elimination optimization. /// -/// Currently only common sub-expressions within one logical plan will +/// This optimization improves query performance by computing expressions that +/// appear more than once and reusing those results rather than re-computing the +/// same value +/// +/// Currently only common sub-expressions within a single `LogicalPlan` are /// be eliminated. +/// +/// # Example +/// +/// Given a projection that computes the same expensive expression +/// multiple times such as parsing as string as a date with `to_date` twice: +/// +/// ```text +/// ProjectionExec(expr=[extract (day from to_date(c1)), extract (year from to_date(c1))]) +/// ``` +/// +/// This optimization will rewrite the plan to compute the common expression once +/// using a new `ProjectionExec` and then rewrite the original expressions to +/// refer to that new column. +/// +/// ```text +/// ProjectionExec(exprs=[extract (day from new_col), extract (year from new_col)]) <-- reuse here +/// ProjectionExec(exprs=[to_date(c1) as new_col]) <-- compute to_date once +/// ``` pub struct CommonSubexprEliminate {} impl CommonSubexprEliminate { From eda2ddfc123a0549c7df7fe0500b48bff1f76910 Mon Sep 17 00:00:00 2001 From: comphead Date: Thu, 21 Mar 2024 11:07:40 -0700 Subject: [PATCH 35/35] build: modify code to comply with latest clippy requirement (#9725) * fix CI clippy * fix scalar size test * fix tests * fix tests --- datafusion/common/src/scalar/mod.rs | 3 ++- datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs | 2 +- datafusion/expr/src/expr_rewriter/mod.rs | 2 +- datafusion/functions/benches/regx.rs | 4 ++-- datafusion/functions/benches/to_char.rs | 2 +- .../optimizer/src/simplify_expressions/expr_simplifier.rs | 3 ++- datafusion/physical-expr/src/equivalence/class.rs | 6 +++--- datafusion/physical-expr/src/equivalence/ordering.rs | 4 ++-- datafusion/physical-expr/src/equivalence/properties.rs | 2 +- datafusion/physical-plan/src/sorts/partial_sort.rs | 2 +- datafusion/physical-plan/src/union.rs | 2 +- datafusion/substrait/src/serializer.rs | 1 + 12 files changed, 18 insertions(+), 15 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index d33b8b6e142c..2a99b667d8f1 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -4539,7 +4539,8 @@ mod tests { // The alignment requirements differ across architectures and // thus the size of the enum appears to as well - assert_eq!(std::mem::size_of::(), 48); + // The value can be changed depending on rust version + assert_eq!(std::mem::size_of::(), 64); } #[test] diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 59905d859dc8..8df16e7944d2 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -46,7 +46,7 @@ use tokio::task::JoinSet; /// same results #[tokio::test(flavor = "multi_thread")] async fn streaming_aggregate_test() { - let test_cases = vec![ + let test_cases = [ vec!["a"], vec!["b", "a"], vec!["c", "a"], diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index ea3ffadda391..7a227a91c455 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -379,7 +379,7 @@ mod test { let expr = col("a") + col("b"); let schema_a = make_schema_with_empty_metadata(vec![make_field("\"tableA\"", "a")]); - let schemas = vec![schema_a]; + let schemas = [schema_a]; let schemas = schemas.iter().collect::>(); let error = diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs index 5831e263b4eb..f22be5ba3532 100644 --- a/datafusion/functions/benches/regx.rs +++ b/datafusion/functions/benches/regx.rs @@ -44,7 +44,7 @@ fn data(rng: &mut ThreadRng) -> StringArray { } fn regex(rng: &mut ThreadRng) -> StringArray { - let samples = vec![ + let samples = [ ".*([A-Z]{1}).*".to_string(), "^(A).*".to_string(), r#"[\p{Letter}-]+"#.to_string(), @@ -60,7 +60,7 @@ fn regex(rng: &mut ThreadRng) -> StringArray { } fn flags(rng: &mut ThreadRng) -> StringArray { - let samples = vec![Some("i".to_string()), Some("im".to_string()), None]; + let samples = [Some("i".to_string()), Some("im".to_string()), None]; let mut sb = StringBuilder::new(); for _ in 0..1000 { let sample = samples.choose(rng).unwrap(); diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index 45a40f175da4..d9a153e64abc 100644 --- a/datafusion/functions/benches/to_char.rs +++ b/datafusion/functions/benches/to_char.rs @@ -64,7 +64,7 @@ fn data(rng: &mut ThreadRng) -> Date32Array { } fn patterns(rng: &mut ThreadRng) -> StringArray { - let samples = vec![ + let samples = [ "%Y:%m:%d".to_string(), "%d-%m-%Y".to_string(), "%d%m%Y".to_string(), diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 61e002ece98b..1cbe7decf15b 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -405,11 +405,12 @@ struct ConstEvaluator<'a> { input_batch: RecordBatch, } +#[allow(dead_code)] /// The simplify result of ConstEvaluator enum ConstSimplifyResult { // Expr was simplifed and contains the new expression Simplified(ScalarValue), - // Evalaution encountered an error, contains the original expression + // Evaluation encountered an error, contains the original expression SimplifyRuntimeError(DataFusionError, Expr), } diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 280535f5e6be..58519c61cf1f 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -535,7 +535,7 @@ mod tests { #[test] fn test_remove_redundant_entries_eq_group() -> Result<()> { - let entries = vec![ + let entries = [ EquivalenceClass::new(vec![lit(1), lit(1), lit(2)]), // This group is meaningless should be removed EquivalenceClass::new(vec![lit(3), lit(3)]), @@ -543,11 +543,11 @@ mod tests { ]; // Given equivalences classes are not in succinct form. // Expected form is the most plain representation that is functionally same. - let expected = vec![ + let expected = [ EquivalenceClass::new(vec![lit(1), lit(2)]), EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), ]; - let mut eq_groups = EquivalenceGroup::new(entries); + let mut eq_groups = EquivalenceGroup::new(entries.to_vec()); eq_groups.remove_redundant_entries(); let eq_groups = eq_groups.classes; diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index c7cb9e5f530e..1364d3a8c028 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -746,7 +746,7 @@ mod tests { // Generate a data that satisfies properties given let table_data_with_properties = generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - let col_exprs = vec![ + let col_exprs = [ col("a", &test_schema)?, col("b", &test_schema)?, col("c", &test_schema)?, @@ -815,7 +815,7 @@ mod tests { Operator::Plus, col("b", &test_schema)?, )) as Arc; - let exprs = vec![ + let exprs = [ col("a", &test_schema)?, col("b", &test_schema)?, col("c", &test_schema)?, diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index a08e85b24162..5eb9d6eb1b86 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -1793,7 +1793,7 @@ mod tests { Operator::Plus, col("b", &test_schema)?, )) as Arc; - let exprs = vec![ + let exprs = [ col("a", &test_schema)?, col("b", &test_schema)?, col("c", &test_schema)?, diff --git a/datafusion/physical-plan/src/sorts/partial_sort.rs b/datafusion/physical-plan/src/sorts/partial_sort.rs index 500df6153fdb..2acb881246a4 100644 --- a/datafusion/physical-plan/src/sorts/partial_sort.rs +++ b/datafusion/physical-plan/src/sorts/partial_sort.rs @@ -578,7 +578,7 @@ mod tests { #[tokio::test] async fn test_partial_sort2() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); - let source_tables = vec![ + let source_tables = [ test::build_table_scan_i32( ("a", &vec![0, 0, 0, 0, 1, 1, 1, 1]), ("b", &vec![1, 1, 3, 3, 4, 4, 2, 2]), diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 7eaac74a5449..64322bd5f101 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -740,7 +740,7 @@ mod tests { let col_e = &col("e", &schema)?; let col_f = &col("f", &schema)?; let options = SortOptions::default(); - let test_cases = vec![ + let test_cases = [ //-----------TEST CASE 1----------// ( // First child orderings diff --git a/datafusion/substrait/src/serializer.rs b/datafusion/substrait/src/serializer.rs index e8698253edb5..6b81e33dfc37 100644 --- a/datafusion/substrait/src/serializer.rs +++ b/datafusion/substrait/src/serializer.rs @@ -27,6 +27,7 @@ use substrait::proto::Plan; use std::fs::OpenOptions; use std::io::{Read, Write}; +#[allow(clippy::suspicious_open_options)] pub async fn serialize(sql: &str, ctx: &SessionContext, path: &str) -> Result<()> { let protobuf_out = serialize_bytes(sql, ctx).await; let mut file = OpenOptions::new().create(true).write(true).open(path)?;