From ecd4793d85b3863041cc6b2492e0c99cd35a5232 Mon Sep 17 00:00:00 2001 From: zhuliquan Date: Thu, 5 Dec 2024 23:44:19 +0800 Subject: [PATCH 01/26] bench: scalar regex match benchmark --- datafusion/core/Cargo.toml | 5 + .../benches/scalar_regex_match_query_sql.rs | 131 ++++++++++++++++++ 2 files changed, 136 insertions(+) create mode 100644 datafusion/core/benches/scalar_regex_match_query_sql.rs diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 4706afc897c2..8b1ff5aa943b 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -216,3 +216,8 @@ name = "topk_aggregate" harness = false name = "map_query_sql" required-features = ["nested_expressions"] + +[[bench]] +harness = false +name = "scalar_regex_match_query_sql" + diff --git a/datafusion/core/benches/scalar_regex_match_query_sql.rs b/datafusion/core/benches/scalar_regex_match_query_sql.rs new file mode 100644 index 000000000000..dbd74cdabb51 --- /dev/null +++ b/datafusion/core/benches/scalar_regex_match_query_sql.rs @@ -0,0 +1,131 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{ + datatypes::{DataType, Field, Schema}, + record_batch::RecordBatch, +}; +use arrow_array::StringArray; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion::prelude::SessionContext; +use datafusion::{datasource::MemTable, error::Result}; +use rand::SeedableRng; +use rand::{rngs::StdRng, Rng}; +use std::sync::Arc; +use tokio::runtime::Runtime; + +fn query(ctx: &SessionContext, sql: &str) { + let rt = Runtime::new().unwrap(); + + // execute the query + let df = rt.block_on(ctx.sql(sql)).unwrap(); + rt.block_on(df.collect()).unwrap(); +} + +fn generate_random_string(rng: &mut StdRng, length: usize, charset: &[u8]) -> String { + (0..length) + .map(|_| { + let idx = rng.gen_range(0..charset.len()); + charset[idx] as char + }) + .collect() +} + +fn create_context( + batch_iter: usize, + batch_size: usize, + string_len: usize, + rand_seed: u64, + correct: &str, +) -> Result { + let mut rng = StdRng::seed_from_u64(rand_seed); + let charset = b"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ.,:/\\+-_!@#$%^&*()~'\"{}[]?"; + + // define a schema. + let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)])); + + // define data. + let batches = (0..batch_iter) + .map(|_| { + let mut array = (0..batch_size - 128) + .map(|_| Some(generate_random_string(&mut rng, string_len, charset))) + .collect::>(); + for _ in 0..128 { + array.push(Some(correct.to_string())); + } + let array = StringArray::from(array); + RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap() + }) + .collect::>(); + + let ctx = SessionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). + let provider = MemTable::try_new(schema, vec![batches])?; + ctx.register_table("t", Arc::new(provider))?; + + Ok(ctx) +} + +fn criterion_benchmark(c: &mut Criterion) { + let batch_iter = 128; + let batch_size = 4096; + c.bench_function("test email address pattern", |b| { + let correct = "test@eaxample.com"; + let sql = "select s from t where s ~ '^[a-zA-Z0-9_\\+\\-]+@[a-zA-Z0-9\\-]+\\.[a-zA-Z]{2,}$'"; + let ctx = create_context(batch_iter, batch_size, 64, 11111, correct).unwrap(); + b.iter(|| query(&ctx, sql)) + }); + + c.bench_function("test ip pattern", |b| { + let correct = "23.7.9.9"; + let ctx = create_context(batch_iter, batch_size, 16, 22222, correct).unwrap(); + let sql = "select s from t where s ~ '^((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$'"; + b.iter(|| query(&ctx, sql)) + }); + + c.bench_function("test phone number pattern", |b| { + let correct = "1236788899"; + let sql = "select s from t where s ~ '^(\\+\\d{1,2}\\s?)?\\(?\\d{3}\\)?[\\s.-]?\\d{3}[\\s.-]?\\d{4}$'"; + let ctx = create_context(batch_iter, batch_size, 16, 33333, correct).unwrap(); + b.iter(|| query(&ctx, sql)) + }); + + c.bench_function("test html tag pattern", |b| { + let correct = "
Hello World
"; + let sql = "select s from t where s ~ '<([a-z1-6]+)>[^<]+'"; + let ctx = create_context(batch_iter, batch_size, 64, 44444, correct).unwrap(); + b.iter(|| query(&ctx, sql)) + }); + + c.bench_function("test url pattern", |b| { + let correct = "https://www.example.com"; + let sql = "select s from t where s ~ '^(https?|ftp|file)://[-A-Za-z0-9+&@#/%?=~_|!:,.;]+[-A-Za-z0-9+&@#/%=~_|]$'"; + let ctx = create_context(batch_iter, batch_size, 64, 55555, correct).unwrap(); + b.iter(|| query(&ctx, sql)) + }); + + c.bench_function("test date pattern", |b| { + let correct = "2024-02-03"; + let sql = "select s from t where s ~ '[0-9]{4}-[0-9]{2}-[0-9]{2}'"; + let ctx = create_context(batch_iter, batch_size, 16, 66666, correct).unwrap(); + b.iter(|| query(&ctx, sql)) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); From c3e0951e7c2f05ea90f3e2be40f9f748ef6c224c Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Fri, 6 Dec 2024 17:58:44 +0100 Subject: [PATCH 02/26] refactor: migrate `LinearSearch` to `HashTable` (#13658) For #13433. --- .../src/windows/bounded_window_agg_exec.rs | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 07d99c8e7129..c1cfd91be052 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -65,7 +65,7 @@ use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; use futures::stream::Stream; use futures::{ready, StreamExt}; -use hashbrown::raw::RawTable; +use hashbrown::hash_table::HashTable; use indexmap::IndexMap; use log::debug; @@ -442,16 +442,16 @@ pub struct LinearSearch { /// is ordered by a, b and the window expression contains a PARTITION BY b, a /// clause, this attribute stores [1, 0]. ordered_partition_by_indices: Vec, - /// We use this [`RawTable`] to calculate unique partitions for each new + /// We use this [`HashTable`] to calculate unique partitions for each new /// RecordBatch. First entry in the tuple is the hash value, the second /// entry is the unique ID for each partition (increments from 0 to n). - row_map_batch: RawTable<(u64, usize)>, - /// We use this [`RawTable`] to calculate the output columns that we can + row_map_batch: HashTable<(u64, usize)>, + /// We use this [`HashTable`] to calculate the output columns that we can /// produce at each cycle. First entry in the tuple is the hash value, the /// second entry is the unique ID for each partition (increments from 0 to n). /// The third entry stores how many new outputs are calculated for the /// corresponding partition. - row_map_out: RawTable<(u64, usize, usize)>, + row_map_out: HashTable<(u64, usize, usize)>, input_schema: SchemaRef, } @@ -610,8 +610,8 @@ impl LinearSearch { input_buffer_hashes: VecDeque::new(), random_state: Default::default(), ordered_partition_by_indices, - row_map_batch: RawTable::with_capacity(256), - row_map_out: RawTable::with_capacity(256), + row_map_batch: HashTable::with_capacity(256), + row_map_out: HashTable::with_capacity(256), input_schema, } } @@ -631,7 +631,7 @@ impl LinearSearch { // res stores PartitionKey and row indices (indices where these partition occurs in the `batch`) for each partition. let mut result: Vec<(PartitionKey, Vec)> = vec![]; for (hash, row_idx) in batch_hashes.into_iter().zip(0u32..) { - let entry = self.row_map_batch.get_mut(hash, |(_, group_idx)| { + let entry = self.row_map_batch.find_mut(hash, |(_, group_idx)| { // We can safely get the first index of the partition indices // since partition indices has one element during initialization. let row = get_row_at_idx(columns, row_idx as usize).unwrap(); @@ -641,8 +641,11 @@ impl LinearSearch { if let Some((_, group_idx)) = entry { result[*group_idx].1.push(row_idx) } else { - self.row_map_batch - .insert(hash, (hash, result.len()), |(hash, _)| *hash); + self.row_map_batch.insert_unique( + hash, + (hash, result.len()), + |(hash, _)| *hash, + ); let row = get_row_at_idx(columns, row_idx as usize)?; // This is a new partition its only index is row_idx for now. result.push((row, vec![row_idx])); @@ -667,7 +670,7 @@ impl LinearSearch { self.row_map_out.clear(); let mut partition_indices: Vec<(PartitionKey, Vec)> = vec![]; for (hash, row_idx) in self.input_buffer_hashes.iter().zip(0u32..) { - let entry = self.row_map_out.get_mut(*hash, |(_, group_idx, _)| { + let entry = self.row_map_out.find_mut(*hash, |(_, group_idx, _)| { let row = get_row_at_idx(&partition_by_columns, row_idx as usize).unwrap(); row == partition_indices[*group_idx].0 @@ -693,7 +696,7 @@ impl LinearSearch { if min_out == 0 { break; } - self.row_map_out.insert( + self.row_map_out.insert_unique( *hash, (*hash, partition_indices.len(), min_out), |(hash, _, _)| *hash, From 1e507ad0f99d8cbd4d1556306cd016e53b7278e9 Mon Sep 17 00:00:00 2001 From: Oleks V Date: Fri, 6 Dec 2024 12:25:44 -0800 Subject: [PATCH 03/26] Minor: Comment temporary function for documentation migration (#13669) * Minor: Comment temporary function for documentation migration * Minor: Comment temporary function for documentation migration --- datafusion/core/src/bin/print_functions_docs.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/bin/print_functions_docs.rs b/datafusion/core/src/bin/print_functions_docs.rs index b58f6e47d333..8b453d5e9698 100644 --- a/datafusion/core/src/bin/print_functions_docs.rs +++ b/datafusion/core/src/bin/print_functions_docs.rs @@ -88,6 +88,7 @@ fn print_window_docs() -> Result { // the migration of UDF documentation generation from code based // to attribute based // To be removed +#[allow(dead_code)] fn save_doc_code_text(documentation: &Documentation, name: &str) { let attr_text = documentation.to_doc_attribute(); @@ -182,7 +183,7 @@ fn print_docs( }; // Temporary for doc gen migration, see `save_doc_code_text` comments - save_doc_code_text(documentation, &name); + // save_doc_code_text(documentation, &name); // first, the name, description and syntax example let _ = write!( From 61fd077c99dbe0f8d2038ff7d49b55a070fb6746 Mon Sep 17 00:00:00 2001 From: Oleks V Date: Fri, 6 Dec 2024 13:13:05 -0800 Subject: [PATCH 04/26] Minor: Rephrase MSRV policy to be more explanatory (#13668) * Minor: Rephrase MSRV policy to be more explanatory Co-authored-by: Andrew Lamb * MSRV policy update --------- Co-authored-by: Andrew Lamb --- README.md | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 553097552418..2e4f2c347fe5 100644 --- a/README.md +++ b/README.md @@ -126,14 +126,17 @@ Optional features: ## Rust Version Compatibility Policy -DataFusion's Minimum Required Stable Rust Version (MSRV) policy is to support stable [4 latest -Rust versions](https://releases.rs) OR the stable minor Rust version as of 4 months, whichever is lower. +The Rust toolchain releases are tracked at [Rust Versions](https://releases.rs) and follow +[semantic versioning](https://semver.org/). A Rust toolchain release can be identified +by a version string like `1.80.0`, or more generally `major.minor.patch`. + +DataFusion's supports the last 4 stable Rust minor versions released and any such versions released within the last 4 months. For example, given the releases `1.78.0`, `1.79.0`, `1.80.0`, `1.80.1` and `1.81.0` DataFusion will support 1.78.0, which is 3 minor versions prior to the most minor recent `1.81`. -If a hotfix is released for the minimum supported Rust version (MSRV), the MSRV will be the minor version with all hotfixes, even if it surpasses the four-month window. +Note: If a Rust hotfix is released for the current MSRV, the MSRV will be updated to the specific minor version that includes all applicable hotfixes preceding other policies. -We enforce this policy using a [MSRV CI Check](https://github.com/search?q=repo%3Aapache%2Fdatafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) +DataFusion enforces MSRV policy using a [MSRV CI Check](https://github.com/search?q=repo%3Aapache%2Fdatafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) ## DataFusion API evolution policy From 67260a0a60c5b96e9639180a13fd297ab65c12e1 Mon Sep 17 00:00:00 2001 From: Eduard Karacharov Date: Sat, 7 Dec 2024 16:00:07 +0200 Subject: [PATCH 05/26] fix: repartitioned reads of CSV with custom line terminator (#13677) --- .../core/src/datasource/physical_plan/csv.rs | 4 ++- .../core/src/datasource/physical_plan/json.rs | 2 +- .../core/src/datasource/physical_plan/mod.rs | 11 +++--- .../sqllogictest/test_files/csv_files.slt | 36 ++++++++++++++----- 4 files changed, 38 insertions(+), 15 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 0c41f69c7691..c54c663dca7d 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -612,11 +612,13 @@ impl FileOpener for CsvOpener { } let store = Arc::clone(&self.config.object_store); + let terminator = self.config.terminator; Ok(Box::pin(async move { // Current partition contains bytes [start_byte, end_byte) (might contain incomplete lines at boundaries) - let calculated_range = calculate_range(&file_meta, &store).await?; + let calculated_range = + calculate_range(&file_meta, &store, terminator).await?; let range = match calculated_range { RangeCalculation::Range(None) => None, diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index c07e8ca74543..5c70968fbb42 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -273,7 +273,7 @@ impl FileOpener for JsonOpener { let file_compression_type = self.file_compression_type.to_owned(); Ok(Box::pin(async move { - let calculated_range = calculate_range(&file_meta, &store).await?; + let calculated_range = calculate_range(&file_meta, &store, None).await?; let range = match calculated_range { RangeCalculation::Range(None) => None, diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 449b7bb43519..3146d124d9f1 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -426,9 +426,11 @@ enum RangeCalculation { async fn calculate_range( file_meta: &FileMeta, store: &Arc, + terminator: Option, ) -> Result { let location = file_meta.location(); let file_size = file_meta.object_meta.size; + let newline = terminator.unwrap_or(b'\n'); match file_meta.range { None => Ok(RangeCalculation::Range(None)), @@ -436,13 +438,13 @@ async fn calculate_range( let (start, end) = (start as usize, end as usize); let start_delta = if start != 0 { - find_first_newline(store, location, start - 1, file_size).await? + find_first_newline(store, location, start - 1, file_size, newline).await? } else { 0 }; let end_delta = if end != file_size { - find_first_newline(store, location, end - 1, file_size).await? + find_first_newline(store, location, end - 1, file_size, newline).await? } else { 0 }; @@ -462,7 +464,7 @@ async fn calculate_range( /// within an object, such as a file, in an object store. /// /// This function scans the contents of the object starting from the specified `start` position -/// up to the `end` position, looking for the first occurrence of a newline (`'\n'`) character. +/// up to the `end` position, looking for the first occurrence of a newline character. /// It returns the position of the first newline relative to the start of the range. /// /// Returns a `Result` wrapping a `usize` that represents the position of the first newline character found within the specified range. If no newline is found, it returns the length of the scanned data, effectively indicating the end of the range. @@ -474,6 +476,7 @@ async fn find_first_newline( location: &Path, start: usize, end: usize, + newline: u8, ) -> Result { let options = GetOptions { range: Some(GetRange::Bounded(start..end)), @@ -486,7 +489,7 @@ async fn find_first_newline( let mut index = 0; while let Some(chunk) = result_stream.next().await.transpose()? { - if let Some(position) = chunk.iter().position(|&byte| byte == b'\n') { + if let Some(position) = chunk.iter().position(|&byte| byte == newline) { return Ok(index + position); } diff --git a/datafusion/sqllogictest/test_files/csv_files.slt b/datafusion/sqllogictest/test_files/csv_files.slt index 01d0f4ac39bd..5906c6a19bb8 100644 --- a/datafusion/sqllogictest/test_files/csv_files.slt +++ b/datafusion/sqllogictest/test_files/csv_files.slt @@ -350,15 +350,33 @@ col2 TEXT LOCATION '../core/tests/data/cr_terminator.csv' OPTIONS ('format.terminator' E'\r', 'format.has_header' 'true'); -# TODO: It should be passed but got the error: External error: query failed: DataFusion error: Object Store error: Generic LocalFileSystem error: Requested range was invalid -# See the issue: https://github.com/apache/datafusion/issues/12328 -# query TT -# select * from stored_table_with_cr_terminator; -# ---- -# id0 value0 -# id1 value1 -# id2 value2 -# id3 value3 +# Check single-thread reading of CSV with custom line terminator +statement ok +SET datafusion.optimizer.repartition_file_min_size = 10485760; + +query TT +select * from stored_table_with_cr_terminator; +---- +id0 value0 +id1 value1 +id2 value2 +id3 value3 + +# Check repartitioned reading of CSV with custom line terminator +statement ok +SET datafusion.optimizer.repartition_file_min_size = 1; + +query TT +select * from stored_table_with_cr_terminator order by col1; +---- +id0 value0 +id1 value1 +id2 value2 +id3 value3 + +# Reset repartition_file_min_size to default value +statement ok +SET datafusion.optimizer.repartition_file_min_size = 10485760; statement ok drop table stored_table_with_cr_terminator; From 3618cfedc9b885fba293f66b58d5c73e9f68a8c9 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Sat, 7 Dec 2024 22:13:56 +0100 Subject: [PATCH 06/26] chore: macros crate cleanup (#13685) * Remove unused dependencies from macros crate * rename macro lib to user_doc --- datafusion-cli/Cargo.lock | 1 - datafusion/macros/Cargo.toml | 4 ++-- datafusion/macros/src/{lib.rs => user_doc.rs} | 0 3 files changed, 2 insertions(+), 3 deletions(-) rename datafusion/macros/src/{lib.rs => user_doc.rs} (100%) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index c871b2fdda08..015bc1e0c382 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1517,7 +1517,6 @@ dependencies = [ name = "datafusion-macros" version = "43.0.0" dependencies = [ - "datafusion-doc", "proc-macro2", "quote", "syn", diff --git a/datafusion/macros/Cargo.toml b/datafusion/macros/Cargo.toml index c5ac9d08dffa..07aa07fa927a 100644 --- a/datafusion/macros/Cargo.toml +++ b/datafusion/macros/Cargo.toml @@ -32,11 +32,11 @@ workspace = true [lib] name = "datafusion_macros" -path = "src/lib.rs" +# lib.rs to be re-added in the future +path = "src/user_doc.rs" proc-macro = true [dependencies] -datafusion-doc = { workspace = true } proc-macro2 = "1.0" quote = "1.0.37" syn = { version = "2.0.79", features = ["full"] } diff --git a/datafusion/macros/src/lib.rs b/datafusion/macros/src/user_doc.rs similarity index 100% rename from datafusion/macros/src/lib.rs rename to datafusion/macros/src/user_doc.rs From d3e08608dabb1793c2a4f4922227c768d99609ca Mon Sep 17 00:00:00 2001 From: Jiashen Cao Date: Sun, 8 Dec 2024 02:48:57 -0500 Subject: [PATCH 07/26] Refactor regexplike signature (#13394) * update * update * update * clean up errors * fix flags types * fix failed example --- datafusion-examples/examples/regexp.rs | 2 +- datafusion/functions/src/regex/regexplike.rs | 50 +++++++++++-------- .../test_files/string/string_view.slt | 2 +- 3 files changed, 31 insertions(+), 23 deletions(-) diff --git a/datafusion-examples/examples/regexp.rs b/datafusion-examples/examples/regexp.rs index 02e74bae22af..5419efd2faea 100644 --- a/datafusion-examples/examples/regexp.rs +++ b/datafusion-examples/examples/regexp.rs @@ -148,7 +148,7 @@ async fn main() -> Result<()> { // invalid flags will result in an error let result = ctx - .sql(r"select regexp_like('\b4(?!000)\d\d\d\b', 4010, 'g')") + .sql(r"select regexp_like('\b4(?!000)\d\d\d\b', '4010', 'g')") .await? .collect() .await; diff --git a/datafusion/functions/src/regex/regexplike.rs b/datafusion/functions/src/regex/regexplike.rs index 49e57776c7b8..1c826b12ef8f 100644 --- a/datafusion/functions/src/regex/regexplike.rs +++ b/datafusion/functions/src/regex/regexplike.rs @@ -81,26 +81,7 @@ impl RegexpLikeFunc { pub fn new() -> Self { Self { signature: Signature::one_of( - vec![ - TypeSignature::Exact(vec![Utf8View, Utf8]), - TypeSignature::Exact(vec![Utf8View, Utf8View]), - TypeSignature::Exact(vec![Utf8View, LargeUtf8]), - TypeSignature::Exact(vec![Utf8, Utf8]), - TypeSignature::Exact(vec![Utf8, Utf8View]), - TypeSignature::Exact(vec![Utf8, LargeUtf8]), - TypeSignature::Exact(vec![LargeUtf8, Utf8]), - TypeSignature::Exact(vec![LargeUtf8, Utf8View]), - TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]), - TypeSignature::Exact(vec![Utf8View, Utf8, Utf8]), - TypeSignature::Exact(vec![Utf8View, Utf8View, Utf8]), - TypeSignature::Exact(vec![Utf8View, LargeUtf8, Utf8]), - TypeSignature::Exact(vec![Utf8, Utf8, Utf8]), - TypeSignature::Exact(vec![Utf8, Utf8View, Utf8]), - TypeSignature::Exact(vec![Utf8, LargeUtf8, Utf8]), - TypeSignature::Exact(vec![LargeUtf8, Utf8, Utf8]), - TypeSignature::Exact(vec![LargeUtf8, Utf8View, Utf8]), - TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Utf8]), - ], + vec![TypeSignature::String(2), TypeSignature::String(3)], Volatility::Immutable, ), } @@ -211,7 +192,34 @@ pub fn regexp_like(args: &[ArrayRef]) -> Result { match args.len() { 2 => handle_regexp_like(&args[0], &args[1], None), 3 => { - let flags = args[2].as_string::(); + let flags = match args[2].data_type() { + Utf8 => args[2].as_string::(), + LargeUtf8 => { + let large_string_array = args[2].as_string::(); + let string_vec: Vec> = (0..large_string_array.len()).map(|i| { + if large_string_array.is_null(i) { + None + } else { + Some(large_string_array.value(i)) + } + }) + .collect(); + + &GenericStringArray::::from(string_vec) + }, + _ => { + let string_view_array = args[2].as_string_view(); + let string_vec: Vec> = (0..string_view_array.len()).map(|i| { + if string_view_array.is_null(i) { + None + } else { + Some(string_view_array.value(i).to_string()) + } + }) + .collect(); + &GenericStringArray::::from(string_vec) + }, + }; if flags.iter().any(|s| s == Some("g")) { return plan_err!("regexp_like() does not support the \"global\" option"); diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt index ebabaf7655ff..c37dd1ed3b4f 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -731,7 +731,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: regexp_like(test.column1_utf8view, Utf8("^https?://(?:www\.)?([^/]+)/.*$")) AS k +01)Projection: regexp_like(test.column1_utf8view, Utf8View("^https?://(?:www\.)?([^/]+)/.*$")) AS k 02)--TableScan: test projection=[column1_utf8view] ## Ensure no casts for REGEXP_MATCH From a960c6d4d8eb4ca02ba927b053662222e60f3a1f Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 8 Dec 2024 08:28:05 -0500 Subject: [PATCH 08/26] Performance: enable array allocation reuse (`ScalarFunctionArgs` gets owned `ColumnReference`) (#13637) * Improve documentation * Pass owned args to ScalarFunctionArgs * Update advanced_udf with example of reusing arrays * clarify rationale for cloning * clarify comments * fix expected output --- datafusion-examples/examples/advanced_udf.rs | 126 ++++++++++++++---- datafusion/expr/src/udf.rs | 14 +- .../functions/src/datetime/to_local_time.rs | 2 +- datafusion/functions/src/string/ascii.rs | 6 +- datafusion/functions/src/string/btrim.rs | 24 ++-- datafusion/functions/src/string/concat.rs | 12 +- datafusion/functions/src/string/concat_ws.rs | 8 +- datafusion/functions/src/string/ends_with.rs | 8 +- datafusion/functions/src/string/initcap.rs | 16 +-- datafusion/functions/src/string/ltrim.rs | 22 +-- .../functions/src/string/octet_length.rs | 20 +-- datafusion/functions/src/string/repeat.rs | 12 +- datafusion/functions/src/string/replace.rs | 6 +- datafusion/functions/src/string/rtrim.rs | 22 +-- datafusion/functions/src/string/split_part.rs | 8 +- .../functions/src/string/starts_with.rs | 2 +- .../functions/src/unicode/character_length.rs | 6 +- datafusion/functions/src/unicode/left.rs | 18 +-- datafusion/functions/src/unicode/lpad.rs | 24 ++-- datafusion/functions/src/unicode/reverse.rs | 6 +- datafusion/functions/src/unicode/right.rs | 18 +-- datafusion/functions/src/unicode/rpad.rs | 28 ++-- datafusion/functions/src/unicode/strpos.rs | 2 +- datafusion/functions/src/unicode/substr.rs | 56 ++++---- .../functions/src/unicode/substrindex.rs | 14 +- datafusion/functions/src/unicode/translate.rs | 12 +- .../physical-expr/src/scalar_function.rs | 8 +- 27 files changed, 288 insertions(+), 212 deletions(-) diff --git a/datafusion-examples/examples/advanced_udf.rs b/datafusion-examples/examples/advanced_udf.rs index aee3be6c9285..ae35cff6facf 100644 --- a/datafusion-examples/examples/advanced_udf.rs +++ b/datafusion-examples/examples/advanced_udf.rs @@ -27,9 +27,11 @@ use arrow::record_batch::RecordBatch; use datafusion::error::Result; use datafusion::logical_expr::Volatility; use datafusion::prelude::*; -use datafusion_common::{internal_err, ScalarValue}; +use datafusion_common::{exec_err, internal_err, ScalarValue}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -use datafusion_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, +}; /// This example shows how to use the full ScalarUDFImpl API to implement a user /// defined function. As in the `simple_udf.rs` example, this struct implements @@ -83,23 +85,27 @@ impl ScalarUDFImpl for PowUdf { Ok(DataType::Float64) } - /// This is the function that actually calculates the results. + /// This function actually calculates the results of the scalar function. + /// + /// This is the same way that functions provided with DataFusion are invoked, + /// which permits important special cases: /// - /// This is the same way that functions built into DataFusion are invoked, - /// which permits important special cases when one or both of the arguments - /// are single values (constants). For example `pow(a, 2)` + ///1. When one or both of the arguments are single values (constants). + /// For example `pow(a, 2)` + /// 2. When the input arrays can be reused (avoid allocating a new output array) /// /// However, it also means the implementation is more complex than when /// using `create_udf`. - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + // The other fields of the `args` struct are used for more specialized + // uses, and are not needed in this example + let ScalarFunctionArgs { mut args, .. } = args; // DataFusion has arranged for the correct inputs to be passed to this // function, but we check again to make sure assert_eq!(args.len(), 2); - let (base, exp) = (&args[0], &args[1]); + // take ownership of arguments by popping in reverse order + let exp = args.pop().unwrap(); + let base = args.pop().unwrap(); assert_eq!(base.data_type(), DataType::Float64); assert_eq!(exp.data_type(), DataType::Float64); @@ -118,7 +124,7 @@ impl ScalarUDFImpl for PowUdf { ) => { // compute the output. Note DataFusion treats `None` as NULL. let res = match (base, exp) { - (Some(base), Some(exp)) => Some(base.powf(*exp)), + (Some(base), Some(exp)) => Some(base.powf(exp)), // one or both arguments were NULL _ => None, }; @@ -140,31 +146,33 @@ impl ScalarUDFImpl for PowUdf { // kernel creates very fast "vectorized" code and // handles things like null values for us. let res: Float64Array = - compute::unary(base_array, |base| base.powf(*exp)); + compute::unary(base_array, |base| base.powf(exp)); Arc::new(res) } }; Ok(ColumnarValue::Array(result_array)) } - // special case if the base is a constant (note this code is quite - // similar to the previous case, so we omit comments) + // special case if the base is a constant. + // + // Note this case is very similar to the previous case, so we could + // use the same pattern. However, for this case we demonstrate an + // even more advanced pattern to potentially avoid allocating a new array ( ColumnarValue::Scalar(ScalarValue::Float64(base)), ColumnarValue::Array(exp_array), ) => { let res = match base { None => new_null_array(exp_array.data_type(), exp_array.len()), - Some(base) => { - let exp_array = exp_array.as_primitive::(); - let res: Float64Array = - compute::unary(exp_array, |exp| base.powf(exp)); - Arc::new(res) - } + Some(base) => maybe_pow_in_place(base, exp_array)?, }; Ok(ColumnarValue::Array(res)) } - // Both arguments are arrays so we have to perform the calculation for every row + // Both arguments are arrays so we have to perform the calculation + // for every row + // + // Note this could also be done in place using `binary_mut` as + // is done in `maybe_pow_in_place` but here we use binary for simplicity (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => { let res: Float64Array = compute::binary( base_array.as_primitive::(), @@ -191,6 +199,52 @@ impl ScalarUDFImpl for PowUdf { } } +/// Evaluate `base ^ exp` *without* allocating a new array, if possible +fn maybe_pow_in_place(base: f64, exp_array: ArrayRef) -> Result { + // Calling `unary` creates a new array for the results. Avoiding + // allocations is a common optimization in performance critical code. + // arrow-rs allows this optimization via the `unary_mut` + // and `binary_mut` kernels in certain cases + // + // These kernels can only be used if there are no other references to + // the arrays (exp_array has to be the last remaining reference). + let owned_array = exp_array + // as in the previous example, we first downcast to &Float64Array + .as_primitive::() + // non-obviously, we call clone here to get an owned `Float64Array`. + // Calling clone() is relatively inexpensive as it increments + // some ref counts but doesn't clone the data) + // + // Once we have the owned Float64Array we can drop the original + // exp_array (untyped) reference + .clone(); + + // We *MUST* drop the reference to `exp_array` explicitly so that + // owned_array is the only reference remaining in this function. + // + // Note that depending on the query there may still be other references + // to the underlying buffers, which would prevent reuse. The only way to + // know for sure is the result of `compute::unary_mut` + drop(exp_array); + + // If we have the only reference, compute the result directly into the same + // allocation as was used for the input array + match compute::unary_mut(owned_array, |exp| base.powf(exp)) { + Err(_orig_array) => { + // unary_mut will return the original array if there are other + // references into the underling buffer (and thus reuse is + // impossible) + // + // In a real implementation, this case should fall back to + // calling `unary` and allocate a new array; In this example + // we will return an error for demonstration purposes + exec_err!("Could not reuse array for maybe_pow_in_place") + } + // a result of OK means the operation was run successfully + Ok(res) => Ok(Arc::new(res)), + } +} + /// In this example we register `PowUdf` as a user defined function /// and invoke it via the DataFrame API and SQL #[tokio::main] @@ -215,9 +269,29 @@ async fn main() -> Result<()> { // print the results df.show().await?; - // You can also invoke both pow(2, 10) and its alias my_pow(a, b) using SQL - let sql_df = ctx.sql("SELECT pow(2, 10), my_pow(a, b) FROM t").await?; - sql_df.show().await?; + // You can also invoke both pow(2, 10) and its alias my_pow(a, b) using SQL + ctx.sql("SELECT pow(2, 10), my_pow(a, b) FROM t") + .await? + .show() + .await?; + + // You can also invoke pow_in_place by passing a constant base and a + // column `a` as the exponent . If there is only a single + // reference to `a` the code works well + ctx.sql("SELECT pow(2, a) FROM t").await?.show().await?; + + // However, if there are multiple references to `a` in the evaluation + // the array storage can not be reused + let err = ctx + .sql("SELECT pow(2, a), pow(3, a) FROM t") + .await? + .show() + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "Execution error: Could not reuse array for maybe_pow_in_place" + ); Ok(()) } diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index bf9c9f407eff..809c78f30eff 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -326,13 +326,15 @@ where } } +/// Arguments passed to [`ScalarUDFImpl::invoke_with_args`] when invoking a +/// scalar function. pub struct ScalarFunctionArgs<'a> { - // The evaluated arguments to the function - pub args: &'a [ColumnarValue], - // The number of rows in record batch being evaluated + /// The evaluated arguments to the function + pub args: Vec, + /// The number of rows in record batch being evaluated pub number_rows: usize, - // The return type of the scalar function returned (from `return_type` or `return_type_from_exprs`) - // when creating the physical expression from the logical expression + /// The return type of the scalar function returned (from `return_type` or `return_type_from_exprs`) + /// when creating the physical expression from the logical expression pub return_type: &'a DataType, } @@ -539,7 +541,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments /// to arrays, which will likely be simpler code, but be slower. fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - self.invoke_batch(args.args, args.number_rows) + self.invoke_batch(&args.args, args.number_rows) } /// Invoke the function without `args`, instead the number of rows are provided, diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index eaa91d1140ba..9f95b780ea4f 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -562,7 +562,7 @@ mod tests { fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) { let res = ToLocalTimeFunc::new() .invoke_with_args(ScalarFunctionArgs { - args: &[ColumnarValue::Scalar(input)], + args: vec![ColumnarValue::Scalar(input)], number_rows: 1, return_type: &expected.data_type(), }) diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 4f615b5b2c58..f366329b4f86 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -157,7 +157,7 @@ mod tests { ($INPUT:expr, $EXPECTED:expr) => { test_function!( AsciiFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], $EXPECTED, i32, Int32, @@ -166,7 +166,7 @@ mod tests { test_function!( AsciiFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], $EXPECTED, i32, Int32, @@ -175,7 +175,7 @@ mod tests { test_function!( AsciiFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], $EXPECTED, i32, Int32, diff --git a/datafusion/functions/src/string/btrim.rs b/datafusion/functions/src/string/btrim.rs index ae79bb59f9c7..298d64f04ae9 100644 --- a/datafusion/functions/src/string/btrim.rs +++ b/datafusion/functions/src/string/btrim.rs @@ -152,9 +152,9 @@ mod tests { // String view cases for checking normal logic test_function!( BTrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from("alphabet ") - ))),], + )))], Ok(Some("alphabet")), &str, Utf8View, @@ -162,7 +162,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from(" alphabet ") ))),], Ok(Some("alphabet")), @@ -172,7 +172,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -185,7 +185,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -200,7 +200,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -214,7 +214,7 @@ mod tests { // Special string view case for checking unlined output(len > 12) test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "xxxalphabetalphabetxxx" )))), @@ -228,7 +228,7 @@ mod tests { // String cases test_function!( BTrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("alphabet ") ))),], Ok(Some("alphabet")), @@ -238,7 +238,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("alphabet ") ))),], Ok(Some("alphabet")), @@ -248,7 +248,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("t")))), ], @@ -259,7 +259,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabe")))), ], @@ -270,7 +270,7 @@ mod tests { ); test_function!( BTrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(None)), ], diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 576c891ce467..895a7cdbf308 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -388,7 +388,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( ConcatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("aa")), ColumnarValue::Scalar(ScalarValue::from("bb")), ColumnarValue::Scalar(ScalarValue::from("cc")), @@ -400,7 +400,7 @@ mod tests { ); test_function!( ConcatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("aa")), ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from("cc")), @@ -412,7 +412,7 @@ mod tests { ); test_function!( ConcatFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))], Ok(Some("")), &str, Utf8, @@ -420,7 +420,7 @@ mod tests { ); test_function!( ConcatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("aa")), ColumnarValue::Scalar(ScalarValue::Utf8View(None)), ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)), @@ -433,7 +433,7 @@ mod tests { ); test_function!( ConcatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("aa")), ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)), ColumnarValue::Scalar(ScalarValue::from("cc")), @@ -445,7 +445,7 @@ mod tests { ); test_function!( ConcatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))), ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))), ], diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 610c4f0be697..7db8dbec4a71 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -404,7 +404,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( ConcatWsFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("|")), ColumnarValue::Scalar(ScalarValue::from("aa")), ColumnarValue::Scalar(ScalarValue::from("bb")), @@ -417,7 +417,7 @@ mod tests { ); test_function!( ConcatWsFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("|")), ColumnarValue::Scalar(ScalarValue::Utf8(None)), ], @@ -428,7 +428,7 @@ mod tests { ); test_function!( ConcatWsFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from("aa")), ColumnarValue::Scalar(ScalarValue::from("bb")), @@ -441,7 +441,7 @@ mod tests { ); test_function!( ConcatWsFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("|")), ColumnarValue::Scalar(ScalarValue::from("aa")), ColumnarValue::Scalar(ScalarValue::Utf8(None)), diff --git a/datafusion/functions/src/string/ends_with.rs b/datafusion/functions/src/string/ends_with.rs index fc7fc04f4363..1632fdd9943e 100644 --- a/datafusion/functions/src/string/ends_with.rs +++ b/datafusion/functions/src/string/ends_with.rs @@ -138,7 +138,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( EndsWithFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from("alph")), ], @@ -149,7 +149,7 @@ mod tests { ); test_function!( EndsWithFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from("bet")), ], @@ -160,7 +160,7 @@ mod tests { ); test_function!( EndsWithFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from("alph")), ], @@ -171,7 +171,7 @@ mod tests { ); test_function!( EndsWithFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::Utf8(None)), ], diff --git a/datafusion/functions/src/string/initcap.rs b/datafusion/functions/src/string/initcap.rs index a9090b0a6f43..338a89091d29 100644 --- a/datafusion/functions/src/string/initcap.rs +++ b/datafusion/functions/src/string/initcap.rs @@ -163,7 +163,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::from("hi THOMAS"))], + vec![ColumnarValue::Scalar(ScalarValue::from("hi THOMAS"))], Ok(Some("Hi Thomas")), &str, Utf8, @@ -171,7 +171,7 @@ mod tests { ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::from(""))], + vec![ColumnarValue::Scalar(ScalarValue::from(""))], Ok(Some("")), &str, Utf8, @@ -179,7 +179,7 @@ mod tests { ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::from(""))], + vec![ColumnarValue::Scalar(ScalarValue::from(""))], Ok(Some("")), &str, Utf8, @@ -187,7 +187,7 @@ mod tests { ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))], Ok(None), &str, Utf8, @@ -195,7 +195,7 @@ mod tests { ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( "hi THOMAS".to_string() )))], Ok(Some("Hi Thomas")), @@ -205,7 +205,7 @@ mod tests { ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( "hi THOMAS wIth M0re ThAN 12 ChaRs".to_string() )))], Ok(Some("Hi Thomas With M0re Than 12 Chars")), @@ -215,7 +215,7 @@ mod tests { ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( "".to_string() )))], Ok(Some("")), @@ -225,7 +225,7 @@ mod tests { ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(None))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(None))], Ok(None), &str, Utf8, diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs index e0e83d1b01e3..b3e7f0bf007d 100644 --- a/datafusion/functions/src/string/ltrim.rs +++ b/datafusion/functions/src/string/ltrim.rs @@ -148,7 +148,7 @@ mod tests { // String view cases for checking normal logic test_function!( LtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from("alphabet ") ))),], Ok(Some("alphabet ")), @@ -158,7 +158,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from(" alphabet ") ))),], Ok(Some("alphabet ")), @@ -168,7 +168,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -181,7 +181,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -196,7 +196,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -210,7 +210,7 @@ mod tests { // Special string view case for checking unlined output(len > 12) test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "xxxalphabetalphabet" )))), @@ -224,7 +224,7 @@ mod tests { // String cases test_function!( LtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("alphabet ") ))),], Ok(Some("alphabet ")), @@ -234,7 +234,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("alphabet ") ))),], Ok(Some("alphabet ")), @@ -244,7 +244,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("t")))), ], @@ -255,7 +255,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabe")))), ], @@ -266,7 +266,7 @@ mod tests { ); test_function!( LtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(None)), ], diff --git a/datafusion/functions/src/string/octet_length.rs b/datafusion/functions/src/string/octet_length.rs index 2dbfa6746d61..26355556ff07 100644 --- a/datafusion/functions/src/string/octet_length.rs +++ b/datafusion/functions/src/string/octet_length.rs @@ -140,7 +140,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Int32(Some(12)))], + vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(12)))], exec_err!( "The OCTET_LENGTH function can only accept strings, but got Int32." ), @@ -150,7 +150,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Array(Arc::new(StringArray::from(vec![ + vec![ColumnarValue::Array(Arc::new(StringArray::from(vec![ String::from("chars"), String::from("chars2"), ])))], @@ -161,7 +161,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("chars")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("chars")))) ], @@ -172,7 +172,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("chars") )))], Ok(Some(5)), @@ -182,7 +182,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("josé") )))], Ok(Some(5)), @@ -192,7 +192,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("") )))], Ok(Some(0)), @@ -202,7 +202,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))], Ok(None), i32, Int32, @@ -210,7 +210,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from("joséjoséjoséjosé") )))], Ok(Some(20)), @@ -220,7 +220,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from("josé") )))], Ok(Some(5)), @@ -230,7 +230,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from("") )))], Ok(Some(0)), diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 4140a9b913ff..d16508c6af5a 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -171,7 +171,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))), ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), ], @@ -182,7 +182,7 @@ mod tests { ); test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), ], @@ -193,7 +193,7 @@ mod tests { ); test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))), ColumnarValue::Scalar(ScalarValue::Int64(None)), ], @@ -205,7 +205,7 @@ mod tests { test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))), ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), ], @@ -216,7 +216,7 @@ mod tests { ); test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(None)), ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), ], @@ -227,7 +227,7 @@ mod tests { ); test_function!( RepeatFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))), ColumnarValue::Scalar(ScalarValue::Int64(None)), ], diff --git a/datafusion/functions/src/string/replace.rs b/datafusion/functions/src/string/replace.rs index 2439799f96d7..9b71d3871ea8 100644 --- a/datafusion/functions/src/string/replace.rs +++ b/datafusion/functions/src/string/replace.rs @@ -157,7 +157,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( ReplaceFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("aabbdqcbb")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("bb")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ccc")))), @@ -170,7 +170,7 @@ mod tests { test_function!( ReplaceFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from( "aabbb" )))), @@ -185,7 +185,7 @@ mod tests { test_function!( ReplaceFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "aabbbcw" )))), diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs index b4fe8d432432..ff8430f1530e 100644 --- a/datafusion/functions/src/string/rtrim.rs +++ b/datafusion/functions/src/string/rtrim.rs @@ -151,7 +151,7 @@ mod tests { // String view cases for checking normal logic test_function!( RtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from("alphabet ") ))),], Ok(Some("alphabet")), @@ -161,7 +161,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( String::from(" alphabet ") ))),], Ok(Some(" alphabet")), @@ -171,7 +171,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -184,7 +184,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -199,7 +199,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -213,7 +213,7 @@ mod tests { // Special string view case for checking unlined output(len > 12) test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabetalphabetxxx" )))), @@ -227,7 +227,7 @@ mod tests { // String cases test_function!( RtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from("alphabet ") ))),], Ok(Some("alphabet")), @@ -237,7 +237,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( String::from(" alphabet ") ))),], Ok(Some(" alphabet")), @@ -247,7 +247,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("t ")))), ], @@ -258,7 +258,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabe")))), ], @@ -269,7 +269,7 @@ mod tests { ); test_function!( RtrimFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), ColumnarValue::Scalar(ScalarValue::Utf8(None)), ], diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index e55325db756d..40bdd3ad01b2 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -270,7 +270,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( SplitPartFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), @@ -284,7 +284,7 @@ mod tests { ); test_function!( SplitPartFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), @@ -298,7 +298,7 @@ mod tests { ); test_function!( SplitPartFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), @@ -312,7 +312,7 @@ mod tests { ); test_function!( SplitPartFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index 36dbd8167b4e..7354fda09584 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -159,7 +159,7 @@ mod tests { for (args, expected) in test_cases { test_function!( StartsWithFunc::new(), - &args, + args, Ok(expected), bool, Boolean, diff --git a/datafusion/functions/src/unicode/character_length.rs b/datafusion/functions/src/unicode/character_length.rs index 726822a8f887..822bdca9aca8 100644 --- a/datafusion/functions/src/unicode/character_length.rs +++ b/datafusion/functions/src/unicode/character_length.rs @@ -176,7 +176,7 @@ mod tests { ($INPUT:expr, $EXPECTED:expr) => { test_function!( CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], $EXPECTED, i32, Int32, @@ -185,7 +185,7 @@ mod tests { test_function!( CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], $EXPECTED, i64, Int64, @@ -194,7 +194,7 @@ mod tests { test_function!( CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], $EXPECTED, i32, Int32, diff --git a/datafusion/functions/src/unicode/left.rs b/datafusion/functions/src/unicode/left.rs index ef2802340b14..e583523d84a0 100644 --- a/datafusion/functions/src/unicode/left.rs +++ b/datafusion/functions/src/unicode/left.rs @@ -188,7 +188,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(2i64)), ], @@ -199,7 +199,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(200i64)), ], @@ -210,7 +210,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(-2i64)), ], @@ -221,7 +221,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(-200i64)), ], @@ -232,7 +232,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(0i64)), ], @@ -243,7 +243,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from(2i64)), ], @@ -254,7 +254,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::Int64(None)), ], @@ -265,7 +265,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("joséésoj")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ], @@ -276,7 +276,7 @@ mod tests { ); test_function!( LeftFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("joséésoj")), ColumnarValue::Scalar(ScalarValue::from(-3i64)), ], diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs index 6c8a4ec97bb0..f1750d2277ca 100644 --- a/datafusion/functions/src/unicode/lpad.rs +++ b/datafusion/functions/src/unicode/lpad.rs @@ -298,7 +298,7 @@ mod tests { ($INPUT:expr, $LENGTH:expr, $EXPECTED:expr) => { test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), ColumnarValue::Scalar($LENGTH) ], @@ -310,7 +310,7 @@ mod tests { test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), ColumnarValue::Scalar($LENGTH) ], @@ -322,7 +322,7 @@ mod tests { test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), ColumnarValue::Scalar($LENGTH) ], @@ -337,7 +337,7 @@ mod tests { // utf8, utf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) @@ -350,7 +350,7 @@ mod tests { // utf8, largeutf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) @@ -363,7 +363,7 @@ mod tests { // utf8, utf8view test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) @@ -377,7 +377,7 @@ mod tests { // largeutf8, utf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) @@ -390,7 +390,7 @@ mod tests { // largeutf8, largeutf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) @@ -403,7 +403,7 @@ mod tests { // largeutf8, utf8view test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) @@ -417,7 +417,7 @@ mod tests { // utf8view, utf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) @@ -430,7 +430,7 @@ mod tests { // utf8view, largeutf8 test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) @@ -443,7 +443,7 @@ mod tests { // utf8view, utf8view test_function!( LPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), ColumnarValue::Scalar($LENGTH), ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) diff --git a/datafusion/functions/src/unicode/reverse.rs b/datafusion/functions/src/unicode/reverse.rs index 38c1f23cbd5a..8e3cf8845f98 100644 --- a/datafusion/functions/src/unicode/reverse.rs +++ b/datafusion/functions/src/unicode/reverse.rs @@ -151,7 +151,7 @@ mod tests { ($INPUT:expr, $EXPECTED:expr) => { test_function!( ReverseFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], $EXPECTED, &str, Utf8, @@ -160,7 +160,7 @@ mod tests { test_function!( ReverseFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], $EXPECTED, &str, LargeUtf8, @@ -169,7 +169,7 @@ mod tests { test_function!( ReverseFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], + vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], $EXPECTED, &str, Utf8, diff --git a/datafusion/functions/src/unicode/right.rs b/datafusion/functions/src/unicode/right.rs index 1586e23eb8aa..4e414fbae5cb 100644 --- a/datafusion/functions/src/unicode/right.rs +++ b/datafusion/functions/src/unicode/right.rs @@ -192,7 +192,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(2i64)), ], @@ -203,7 +203,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(200i64)), ], @@ -214,7 +214,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(-2i64)), ], @@ -225,7 +225,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(-200i64)), ], @@ -236,7 +236,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::from(0i64)), ], @@ -247,7 +247,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from(2i64)), ], @@ -258,7 +258,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abcde")), ColumnarValue::Scalar(ScalarValue::Int64(None)), ], @@ -269,7 +269,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("joséésoj")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ], @@ -280,7 +280,7 @@ mod tests { ); test_function!( RightFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("joséésoj")), ColumnarValue::Scalar(ScalarValue::from(-3i64)), ], diff --git a/datafusion/functions/src/unicode/rpad.rs b/datafusion/functions/src/unicode/rpad.rs index 6e6bde3e177c..d5a0079c72aa 100644 --- a/datafusion/functions/src/unicode/rpad.rs +++ b/datafusion/functions/src/unicode/rpad.rs @@ -319,7 +319,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("josé")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ], @@ -330,7 +330,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ], @@ -341,7 +341,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::from(0i64)), ], @@ -352,7 +352,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::Int64(None)), ], @@ -363,7 +363,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from(5i64)), ], @@ -374,7 +374,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ColumnarValue::Scalar(ScalarValue::from("xy")), @@ -386,7 +386,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::from(21i64)), ColumnarValue::Scalar(ScalarValue::from("abcdef")), @@ -398,7 +398,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ColumnarValue::Scalar(ScalarValue::from(" ")), @@ -410,7 +410,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ColumnarValue::Scalar(ScalarValue::from("")), @@ -422,7 +422,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from(5i64)), ColumnarValue::Scalar(ScalarValue::from("xy")), @@ -434,7 +434,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::Int64(None)), ColumnarValue::Scalar(ScalarValue::from("xy")), @@ -446,7 +446,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("hi")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ColumnarValue::Scalar(ScalarValue::Utf8(None)), @@ -458,7 +458,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("josé")), ColumnarValue::Scalar(ScalarValue::from(10i64)), ColumnarValue::Scalar(ScalarValue::from("xy")), @@ -470,7 +470,7 @@ mod tests { ); test_function!( RPadFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("josé")), ColumnarValue::Scalar(ScalarValue::from(10i64)), ColumnarValue::Scalar(ScalarValue::from("éñ")), diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index 5d1986e44c92..569af87a4b50 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -218,7 +218,7 @@ mod tests { ($lhs:literal, $rhs:literal -> $result:literal; $t1:ident $t2:ident $t3:ident $t4:ident $t5:ident) => { test_function!( StrposFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::$t1(Some($lhs.to_owned()))), ColumnarValue::Scalar(ScalarValue::$t2(Some($rhs.to_owned()))), ], diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 0ac050c707bf..141984cf2674 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -523,7 +523,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(None)), ColumnarValue::Scalar(ScalarValue::from(1i64)), ], @@ -534,7 +534,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -547,7 +547,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "this és longer than 12B" )))), @@ -561,7 +561,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "this is longer than 12B" )))), @@ -574,7 +574,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "joséésoj" )))), @@ -587,7 +587,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -601,7 +601,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( "alphabet" )))), @@ -615,7 +615,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(0i64)), ], @@ -626,7 +626,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("joséésoj")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ], @@ -637,7 +637,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("joséésoj")), ColumnarValue::Scalar(ScalarValue::from(-5i64)), ], @@ -648,7 +648,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(1i64)), ], @@ -659,7 +659,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(2i64)), ], @@ -670,7 +670,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(3i64)), ], @@ -681,7 +681,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(-3i64)), ], @@ -692,7 +692,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(30i64)), ], @@ -703,7 +703,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::Int64(None)), ], @@ -714,7 +714,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(3i64)), ColumnarValue::Scalar(ScalarValue::from(2i64)), @@ -726,7 +726,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(3i64)), ColumnarValue::Scalar(ScalarValue::from(20i64)), @@ -738,7 +738,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(0i64)), ColumnarValue::Scalar(ScalarValue::from(5i64)), @@ -751,7 +751,7 @@ mod tests { // starting from 5 (10 + -5) test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(-5i64)), ColumnarValue::Scalar(ScalarValue::from(10i64)), @@ -764,7 +764,7 @@ mod tests { // starting from -1 (4 + -5) test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(-5i64)), ColumnarValue::Scalar(ScalarValue::from(4i64)), @@ -777,7 +777,7 @@ mod tests { // starting from 0 (5 + -5) test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(-5i64)), ColumnarValue::Scalar(ScalarValue::from(5i64)), @@ -789,7 +789,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::Int64(None)), ColumnarValue::Scalar(ScalarValue::from(20i64)), @@ -801,7 +801,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(3i64)), ColumnarValue::Scalar(ScalarValue::Int64(None)), @@ -813,7 +813,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("alphabet")), ColumnarValue::Scalar(ScalarValue::from(1i64)), ColumnarValue::Scalar(ScalarValue::from(-1i64)), @@ -825,7 +825,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("joséésoj")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ColumnarValue::Scalar(ScalarValue::from(2i64)), @@ -851,7 +851,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("abc")), ColumnarValue::Scalar(ScalarValue::from(-9223372036854775808i64)), ], @@ -862,7 +862,7 @@ mod tests { ); test_function!( SubstrFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("overflow")), ColumnarValue::Scalar(ScalarValue::from(-9223372036854775808i64)), ColumnarValue::Scalar(ScalarValue::from(1i64)), diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs index 825666b0455e..61cd989bb964 100644 --- a/datafusion/functions/src/unicode/substrindex.rs +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -253,7 +253,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), ColumnarValue::Scalar(ScalarValue::from(".")), ColumnarValue::Scalar(ScalarValue::from(1i64)), @@ -265,7 +265,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), ColumnarValue::Scalar(ScalarValue::from(".")), ColumnarValue::Scalar(ScalarValue::from(2i64)), @@ -277,7 +277,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), ColumnarValue::Scalar(ScalarValue::from(".")), ColumnarValue::Scalar(ScalarValue::from(-2i64)), @@ -289,7 +289,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), ColumnarValue::Scalar(ScalarValue::from(".")), ColumnarValue::Scalar(ScalarValue::from(-1i64)), @@ -301,7 +301,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), ColumnarValue::Scalar(ScalarValue::from(".")), ColumnarValue::Scalar(ScalarValue::from(0i64)), @@ -313,7 +313,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("")), ColumnarValue::Scalar(ScalarValue::from(".")), ColumnarValue::Scalar(ScalarValue::from(1i64)), @@ -325,7 +325,7 @@ mod tests { ); test_function!( SubstrIndexFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), ColumnarValue::Scalar(ScalarValue::from("")), ColumnarValue::Scalar(ScalarValue::from(1i64)), diff --git a/datafusion/functions/src/unicode/translate.rs b/datafusion/functions/src/unicode/translate.rs index 780603777133..9257b0b04e61 100644 --- a/datafusion/functions/src/unicode/translate.rs +++ b/datafusion/functions/src/unicode/translate.rs @@ -201,7 +201,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("12345")), ColumnarValue::Scalar(ScalarValue::from("143")), ColumnarValue::Scalar(ScalarValue::from("ax")) @@ -213,7 +213,7 @@ mod tests { ); test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from("143")), ColumnarValue::Scalar(ScalarValue::from("ax")) @@ -225,7 +225,7 @@ mod tests { ); test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("12345")), ColumnarValue::Scalar(ScalarValue::Utf8(None)), ColumnarValue::Scalar(ScalarValue::from("ax")) @@ -237,7 +237,7 @@ mod tests { ); test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("12345")), ColumnarValue::Scalar(ScalarValue::from("143")), ColumnarValue::Scalar(ScalarValue::Utf8(None)) @@ -249,7 +249,7 @@ mod tests { ); test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("é2íñ5")), ColumnarValue::Scalar(ScalarValue::from("éñí")), ColumnarValue::Scalar(ScalarValue::from("óü")), @@ -262,7 +262,7 @@ mod tests { #[cfg(not(feature = "unicode_expressions"))] test_function!( TranslateFunc::new(), - &[ + vec![ ColumnarValue::Scalar(ScalarValue::from("12345")), ColumnarValue::Scalar(ScalarValue::from("143")), ColumnarValue::Scalar(ScalarValue::from("ax")), diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 45f77325eea3..e312d5de59fb 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -134,20 +134,20 @@ impl PhysicalExpr for ScalarFunctionExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { - let inputs = self + let args = self .args .iter() .map(|e| e.evaluate(batch)) .collect::>>()?; - let input_empty = inputs.is_empty(); - let input_all_scalar = inputs + let input_empty = args.is_empty(); + let input_all_scalar = args .iter() .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); // evaluate the function let output = self.fun.invoke_with_args(ScalarFunctionArgs { - args: inputs.as_slice(), + args, number_rows: batch.num_rows(), return_type: &self.return_type, })?; From d39852de4363255bb7784c6e6eefd41d169675ab Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Sun, 8 Dec 2024 21:58:18 +0800 Subject: [PATCH 09/26] Temporary fix for CI (#13689) --- datafusion/core/Cargo.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 8b1ff5aa943b..5241603f3aa2 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -132,6 +132,10 @@ xz2 = { version = "0.1", optional = true, features = ["static"] } zstd = { version = "0.13", optional = true, default-features = false } [dev-dependencies] +# Temporary fix for https://github.com/apache/datafusion/issues/13686 +# TODO: Remove it once the upstream has a fix +lexical-write-integer = { version = "=1.0.2" } + arrow-buffer = { workspace = true } async-trait = { workspace = true } criterion = { version = "0.5", features = ["async_tokio"] } From 98372cc302053b4e5f5ccfa88045c0363eb289d9 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Sun, 8 Dec 2024 23:37:53 +0800 Subject: [PATCH 10/26] refactor: use `LazyLock` in the `user_doc` macro (#13684) * refactor: use `LazyLock` in the `user_doc` macro * Fix cargo doc * Update datafusion/macros/src/lib.rs * Fix doc comment --------- Co-authored-by: Oleks V --- .../src/approx_distinct.rs | 2 - .../functions-aggregate/src/approx_median.rs | 2 - .../src/approx_percentile_cont.rs | 3 +- .../src/approx_percentile_cont_with_weight.rs | 3 +- .../functions-aggregate/src/array_agg.rs | 3 +- datafusion/functions-aggregate/src/average.rs | 3 +- .../functions-aggregate/src/bool_and_or.rs | 2 - .../functions-aggregate/src/correlation.rs | 3 +- datafusion/functions-aggregate/src/count.rs | 3 +- .../functions-aggregate/src/covariance.rs | 2 - .../functions-aggregate/src/first_last.rs | 3 +- .../functions-aggregate/src/grouping.rs | 2 - datafusion/functions-aggregate/src/median.rs | 3 +- datafusion/functions-aggregate/src/min_max.rs | 2 - .../functions-aggregate/src/nth_value.rs | 3 +- datafusion/functions-aggregate/src/stddev.rs | 3 +- .../functions-aggregate/src/string_agg.rs | 2 - datafusion/functions-aggregate/src/sum.rs | 2 - .../functions-aggregate/src/variance.rs | 2 - datafusion/functions/src/datetime/to_date.rs | 2 - datafusion/functions/src/math/abs.rs | 3 +- datafusion/functions/src/string/ltrim.rs | 2 - datafusion/macros/src/user_doc.rs | 105 ++++++++++-------- 23 files changed, 68 insertions(+), 92 deletions(-) diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index 74691ba740fd..1d378fff176f 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -31,7 +31,6 @@ use datafusion_common::ScalarValue; use datafusion_common::{ downcast_value, internal_err, not_impl_err, DataFusionError, Result, }; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ @@ -42,7 +41,6 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; use std::hash::Hash; use std::marker::PhantomData; -use std::sync::OnceLock; make_udaf_expr_and_func!( ApproxDistinct, diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index d4441da61292..5d174a752296 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -19,13 +19,11 @@ use std::any::Any; use std::fmt::Debug; -use std::sync::OnceLock; use arrow::{datatypes::DataType, datatypes::Field}; use arrow_schema::DataType::{Float64, UInt64}; use datafusion_common::{not_impl_err, plan_err, Result}; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 13407fecf220..61424e8f2445 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -18,7 +18,7 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; use std::mem::size_of_val; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{Array, RecordBatch}; use arrow::compute::{filter, is_not_null}; @@ -35,7 +35,6 @@ use datafusion_common::{ downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, }; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; use datafusion_expr::utils::format_state_name; diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index 485874aeb284..10b9b06f1f94 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -18,7 +18,7 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; use std::mem::size_of_val; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::{ array::ArrayRef, @@ -27,7 +27,6 @@ use arrow::{ use datafusion_common::ScalarValue; use datafusion_common::{not_impl_err, plan_err, Result}; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::Volatility::Immutable; diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 98530a9fc236..b75de83f6ace 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -25,7 +25,6 @@ use datafusion_common::cast::as_list_array; use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; use datafusion_common::{exec_err, ScalarValue}; use datafusion_common::{internal_err, Result}; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{Accumulator, Signature, Volatility}; @@ -36,7 +35,7 @@ use datafusion_macros::user_doc; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use std::collections::{HashSet, VecDeque}; use std::mem::{size_of, size_of_val}; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; make_udaf_expr_and_func!( ArrayAgg, diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 65ca441517a0..18874f831e9d 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -42,14 +42,13 @@ use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls: filtered_null_mask, set_nulls, }; -use datafusion_doc::DocSection; use datafusion_functions_aggregate_common::utils::DecimalAverager; use datafusion_macros::user_doc; use log::debug; use std::any::Any; use std::fmt::Debug; use std::mem::{size_of, size_of_val}; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; make_udaf_expr_and_func!( Avg, diff --git a/datafusion/functions-aggregate/src/bool_and_or.rs b/datafusion/functions-aggregate/src/bool_and_or.rs index 1b5b20f43b3e..29dfc68e0576 100644 --- a/datafusion/functions-aggregate/src/bool_and_or.rs +++ b/datafusion/functions-aggregate/src/bool_and_or.rs @@ -19,7 +19,6 @@ use std::any::Any; use std::mem::size_of_val; -use std::sync::OnceLock; use arrow::array::ArrayRef; use arrow::array::BooleanArray; @@ -38,7 +37,6 @@ use datafusion_expr::{ Signature, Volatility, }; -use datafusion_doc::DocSection; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator; use datafusion_macros::user_doc; diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index b40555bf6c7f..a0ccdb0ae7d0 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -20,7 +20,7 @@ use std::any::Any; use std::fmt::Debug; use std::mem::size_of_val; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::compute::{and, filter, is_not_null}; use arrow::{ @@ -31,7 +31,6 @@ use arrow::{ use crate::covariance::CovarianceAccumulator; use crate::stddev::StddevAccumulator; use datafusion_common::{plan_err, Result, ScalarValue}; -use datafusion_doc::DocSection; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, type_coercion::aggregates::NUMERICS, diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 550df8cb4f7d..b4164c211c35 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -17,7 +17,6 @@ use ahash::RandomState; use datafusion_common::stats::Precision; -use datafusion_doc::DocSection; use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator; use datafusion_macros::user_doc; use datafusion_physical_expr::expressions; @@ -25,7 +24,7 @@ use std::collections::HashSet; use std::fmt::Debug; use std::mem::{size_of, size_of_val}; use std::ops::BitAnd; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::{ array::{ArrayRef, AsArray}, diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs index adb546e4d906..ffbf2ceef052 100644 --- a/datafusion/functions-aggregate/src/covariance.rs +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -19,7 +19,6 @@ use std::fmt::Debug; use std::mem::size_of_val; -use std::sync::OnceLock; use arrow::{ array::{ArrayRef, Float64Array, UInt64Array}, @@ -31,7 +30,6 @@ use datafusion_common::{ downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, Result, ScalarValue, }; -use datafusion_doc::DocSection; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, type_coercion::aggregates::NUMERICS, diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index f3e66edbc009..9ad55d91a68b 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -20,7 +20,7 @@ use std::any::Any; use std::fmt::Debug; use std::mem::size_of_val; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{ArrayRef, AsArray, BooleanArray}; use arrow::compute::{self, lexsort_to_indices, take_arrays, SortColumn}; @@ -29,7 +29,6 @@ use datafusion_common::utils::{compare_rows, get_row_at_idx}; use datafusion_common::{ arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs index 36bdf68c1b0e..445774ff11e7 100644 --- a/datafusion/functions-aggregate/src/grouping.rs +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -19,12 +19,10 @@ use std::any::Any; use std::fmt; -use std::sync::OnceLock; use arrow::datatypes::DataType; use arrow::datatypes::Field; use datafusion_common::{not_impl_err, Result}; -use datafusion_doc::DocSection; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::utils::format_state_name; diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index db5fbf00165f..70f192c32ae1 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -18,7 +18,7 @@ use std::cmp::Ordering; use std::fmt::{Debug, Formatter}; use std::mem::{size_of, size_of_val}; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{downcast_integer, ArrowNumericType}; use arrow::{ @@ -34,7 +34,6 @@ use arrow::array::ArrowNativeTypeOp; use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType}; use datafusion_common::{DataFusionError, HashSet, Result, ScalarValue}; -use datafusion_doc::DocSection; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::{ function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index acbeebaad68b..a0f7634c5fa8 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -55,7 +55,6 @@ use arrow::datatypes::{ use crate::min_max::min_max_bytes::MinMaxBytesAccumulator; use datafusion_common::ScalarValue; -use datafusion_doc::DocSection; use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, @@ -65,7 +64,6 @@ use datafusion_macros::user_doc; use half::f16; use std::mem::size_of_val; use std::ops::Deref; -use std::sync::OnceLock; fn get_min_max_result_type(input_types: &[DataType]) -> Result> { // make sure that the input types only has one element. diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index 15b9e97516ca..8252fd6baaa3 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -21,14 +21,13 @@ use std::any::Any; use std::collections::VecDeque; use std::mem::{size_of, size_of_val}; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray}; use arrow_schema::{DataType, Field, Fields}; use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue}; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 09a39e342cce..adf86a128cfb 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -20,14 +20,13 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; use std::mem::align_of_val; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::Float64Array; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_common::{plan_err, ScalarValue}; -use datafusion_doc::DocSection; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 5a52bec55f15..7643b44e11d5 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -22,7 +22,6 @@ use arrow_schema::DataType; use datafusion_common::cast::as_generic_string_array; use datafusion_common::Result; use datafusion_common::{not_impl_err, ScalarValue}; -use datafusion_doc::DocSection; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility, @@ -31,7 +30,6 @@ use datafusion_macros::user_doc; use datafusion_physical_expr::expressions::Literal; use std::any::Any; use std::mem::size_of_val; -use std::sync::OnceLock; make_udaf_expr_and_func!( StringAgg, diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index ccc6ee3cf925..6c2854f6bc24 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -22,7 +22,6 @@ use datafusion_expr::utils::AggregateOrderSensitivity; use std::any::Any; use std::collections::HashSet; use std::mem::{size_of, size_of_val}; -use std::sync::OnceLock; use arrow::array::Array; use arrow::array::ArrowNativeTypeOp; @@ -35,7 +34,6 @@ use arrow::datatypes::{ }; use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; -use datafusion_doc::DocSection; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::utils::format_state_name; diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 70b10734088f..8aa7a40ce320 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -25,13 +25,11 @@ use arrow::{ datatypes::{DataType, Field}, }; use std::mem::{size_of, size_of_val}; -use std::sync::OnceLock; use std::{fmt::Debug, sync::Arc}; use datafusion_common::{ downcast_value, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, }; -use datafusion_doc::DocSection; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, utils::format_state_name, diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index e2edea843e98..091d0ba37644 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -22,13 +22,11 @@ use arrow::error::ArrowError::ParseError; use arrow::{array::types::Date32Type, compute::kernels::cast_utils::Parser}; use datafusion_common::error::DataFusionError; use datafusion_common::{arrow_err, exec_err, internal_datafusion_err, Result}; -use datafusion_doc::DocSection; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; use std::any::Any; -use std::sync::OnceLock; #[user_doc( doc_section(label = "Time and Date Functions"), diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index c0c7c6f0f6b6..e3d448083e26 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -18,7 +18,7 @@ //! math expressions use std::any::Any; -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use arrow::array::{ ArrayRef, Decimal128Array, Decimal256Array, Float32Array, Float64Array, Int16Array, @@ -27,7 +27,6 @@ use arrow::array::{ use arrow::datatypes::DataType; use arrow::error::ArrowError; use datafusion_common::{exec_err, not_impl_err, DataFusionError, Result}; -use datafusion_doc::DocSection; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs index b3e7f0bf007d..0bc62ee5000d 100644 --- a/datafusion/functions/src/string/ltrim.rs +++ b/datafusion/functions/src/string/ltrim.rs @@ -22,12 +22,10 @@ use std::any::Any; use crate::string::common::*; use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::{exec_err, Result}; -use datafusion_doc::DocSection; use datafusion_expr::function::Hint; use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; -use std::sync::OnceLock; /// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed. /// ltrim('zzzytest', 'xyz') = 'test' diff --git a/datafusion/macros/src/user_doc.rs b/datafusion/macros/src/user_doc.rs index 54b688ac2a49..441b3db2a133 100644 --- a/datafusion/macros/src/user_doc.rs +++ b/datafusion/macros/src/user_doc.rs @@ -26,16 +26,19 @@ use syn::{parse_macro_input, DeriveInput, LitStr}; /// declared on `AggregateUDF`, `WindowUDFImpl`, `ScalarUDFImpl` traits. /// /// Example: +/// ```ignore /// #[user_doc( /// doc_section(include = "true", label = "Time and Date Functions"), -/// description = r"Converts a value to a date (`YYYY-MM-DD`)." -/// sql_example = "```sql\n\ -/// \> select to_date('2023-01-31');\n\ -/// +-----------------------------+\n\ -/// | to_date(Utf8(\"2023-01-31\")) |\n\ -/// +-----------------------------+\n\ -/// | 2023-01-31 |\n\ -/// +-----------------------------+\n\"), +/// description = r"Converts a value to a date (`YYYY-MM-DD`).", +/// syntax_example = "to_date('2017-05-31', '%Y-%m-%d')", +/// sql_example = r#"```sql +/// > select to_date('2023-01-31'); +/// +-----------------------------+ +/// | to_date(Utf8(\"2023-01-31\")) | +/// +-----------------------------+ +/// | 2023-01-31 | +/// +-----------------------------+ +/// ```"#, /// standard_argument(name = "expression", prefix = "String"), /// argument( /// name = "format_n", @@ -48,40 +51,50 @@ use syn::{parse_macro_input, DeriveInput, LitStr}; /// pub struct ToDateFunc { /// signature: Signature, /// } -/// +/// ``` /// will generate the following code /// -/// #[derive(Debug)] pub struct ToDateFunc { signature : Signature, } -/// use datafusion_doc :: DocSection; -/// use datafusion_doc :: DocumentationBuilder; -/// static DOCUMENTATION : OnceLock < Documentation > = OnceLock :: new(); -/// impl ToDateFunc -/// { -/// fn doc(& self) -> Option < & Documentation > -/// { -/// Some(DOCUMENTATION.get_or_init(|| -/// { -/// Documentation :: -/// builder(DocSection -/// { -/// include : true, label : "Time and Date Functions", description -/// : None -/// }, r"Converts a value to a date (`YYYY-MM-DD`).") -/// .with_syntax_example("to_date('2017-05-31', '%Y-%m-%d')".to_string(),"```sql\n\ -/// \> select to_date('2023-01-31');\n\ -/// +-----------------------------+\n\ -/// | to_date(Utf8(\"2023-01-31\")) |\n\ -/// +-----------------------------+\n\ -/// | 2023-01-31 |\n\ -/// +-----------------------------+\n\) -/// .with_standard_argument("expression", "String".into()) -/// .with_argument("format_n", -/// r"Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order -/// they appear with the first successful one being returned. If none of the formats successfully parse the expression -/// an error will be returned.").build() -/// })) +/// ```ignore +/// pub struct ToDateFunc { +/// signature: Signature, +/// } +/// impl ToDateFunc { +/// fn doc(&self) -> Option<&datafusion_doc::Documentation> { +/// static DOCUMENTATION: std::sync::LazyLock< +/// datafusion_doc::Documentation, +/// > = std::sync::LazyLock::new(|| { +/// datafusion_doc::Documentation::builder( +/// datafusion_doc::DocSection { +/// include: true, +/// label: "Time and Date Functions", +/// description: None, +/// }, +/// r"Converts a value to a date (`YYYY-MM-DD`).".to_string(), +/// "to_date('2017-05-31', '%Y-%m-%d')".to_string(), +/// ) +/// .with_sql_example( +/// r#"```sql +/// > select to_date('2023-01-31'); +/// +-----------------------------+ +/// | to_date(Utf8(\"2023-01-31\")) | +/// +-----------------------------+ +/// | 2023-01-31 | +/// +-----------------------------+ +/// ```"#, +/// ) +/// .with_standard_argument("expression", "String".into()) +/// .with_argument( +/// "format_n", +/// r"Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order +/// they appear with the first successful one being returned. If none of the formats successfully parse the expression +/// an error will be returned.", +/// ) +/// .build() +/// }); +/// Some(&DOCUMENTATION) /// } /// } +/// ``` #[proc_macro_attribute] pub fn user_doc(args: TokenStream, input: TokenStream) -> TokenStream { let mut doc_section_include: Option = None; @@ -235,19 +248,14 @@ pub fn user_doc(args: TokenStream, input: TokenStream) -> TokenStream { } }); - let lock_name: proc_macro2::TokenStream = - format!("{name}_DOCUMENTATION").parse().unwrap(); - let generated = quote! { #input - static #lock_name: OnceLock = OnceLock::new(); - impl #name { - - fn doc(&self) -> Option<&Documentation> { - Some(#lock_name.get_or_init(|| { - Documentation::builder(DocSection { include: #doc_section_include, label: #doc_section_lbl, description: #doc_section_description }, + fn doc(&self) -> Option<&datafusion_doc::Documentation> { + static DOCUMENTATION: std::sync::LazyLock = + std::sync::LazyLock::new(|| { + datafusion_doc::Documentation::builder(datafusion_doc::DocSection { include: #doc_section_include, label: #doc_section_lbl, description: #doc_section_description }, #description.to_string(), #syntax_example.to_string()) #sql_example #alt_syntax_example @@ -255,7 +263,8 @@ pub fn user_doc(args: TokenStream, input: TokenStream) -> TokenStream { #(#udf_args)* #(#related_udfs)* .build() - })) + }); + Some(&DOCUMENTATION) } } }; From e8226f5f8dc6a3fd539524f7d250e6524b0259ee Mon Sep 17 00:00:00 2001 From: Alexander Huszagh Date: Sun, 8 Dec 2024 19:51:44 -0600 Subject: [PATCH 11/26] Unlock lexical-write-integer version. (#13693) Issue was patched as of lexical release 1.0.5. Reverts #13689 Closes #13686 --- datafusion/core/Cargo.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 5241603f3aa2..8b1ff5aa943b 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -132,10 +132,6 @@ xz2 = { version = "0.1", optional = true, features = ["static"] } zstd = { version = "0.13", optional = true, default-features = false } [dev-dependencies] -# Temporary fix for https://github.com/apache/datafusion/issues/13686 -# TODO: Remove it once the upstream has a fix -lexical-write-integer = { version = "=1.0.2" } - arrow-buffer = { workspace = true } async-trait = { workspace = true } criterion = { version = "0.5", features = ["async_tokio"] } From bd91271b884a55a8f05df9afb0c82ac098572de6 Mon Sep 17 00:00:00 2001 From: Mustafa Akur Date: Mon, 9 Dec 2024 04:05:37 -0800 Subject: [PATCH 12/26] Minor: Use `div_ceil` --- .../physical-plan/src/joins/cross_join.rs | 27 ++++++++++--------- .../src/joins/nested_loop_join.rs | 24 ++++++++--------- 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index f53fe13df15e..8bf675e87362 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -190,18 +190,21 @@ async fn load_left_input( // Load all batches and count the rows let (batches, _metrics, reservation) = stream - .try_fold((Vec::new(), metrics, reservation), |mut acc, batch| async { - let batch_size = batch.get_array_memory_size(); - // Reserve memory for incoming batch - acc.2.try_grow(batch_size)?; - // Update metrics - acc.1.build_mem_used.add(batch_size); - acc.1.build_input_batches.add(1); - acc.1.build_input_rows.add(batch.num_rows()); - // Push batch to output - acc.0.push(batch); - Ok(acc) - }) + .try_fold( + (Vec::new(), metrics, reservation), + |(mut batches, metrics, mut reservation), batch| async { + let batch_size = batch.get_array_memory_size(); + // Reserve memory for incoming batch + reservation.try_grow(batch_size)?; + // Update metrics + metrics.build_mem_used.add(batch_size); + metrics.build_input_batches.add(1); + metrics.build_input_rows.add(batch.num_rows()); + // Push batch to output + batches.push(batch); + Ok((batches, metrics, reservation)) + }, + ) .await?; let merged_batch = concat_batches(&left_schema, &batches)?; diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 2beeb92da499..d174564178df 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -45,7 +45,6 @@ use arrow::array::{BooleanBufferBuilder, UInt32Array, UInt64Array}; use arrow::compute::concat_batches; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use arrow::util::bit_util; use datafusion_common::{ exec_datafusion_err, internal_err, JoinSide, Result, Statistics, }; @@ -440,17 +439,17 @@ async fn collect_left_input( let (batches, metrics, mut reservation) = stream .try_fold( (Vec::new(), join_metrics, reservation), - |mut acc, batch| async { + |(mut batches, metrics, mut reservation), batch| async { let batch_size = batch.get_array_memory_size(); // Reserve memory for incoming batch - acc.2.try_grow(batch_size)?; + reservation.try_grow(batch_size)?; // Update metrics - acc.1.build_mem_used.add(batch_size); - acc.1.build_input_batches.add(1); - acc.1.build_input_rows.add(batch.num_rows()); + metrics.build_mem_used.add(batch_size); + metrics.build_input_batches.add(1); + metrics.build_input_rows.add(batch.num_rows()); // Push batch to output - acc.0.push(batch); - Ok(acc) + batches.push(batch); + Ok((batches, metrics, reservation)) }, ) .await?; @@ -459,14 +458,13 @@ async fn collect_left_input( // Reserve memory for visited_left_side bitmap if required by join type let visited_left_side = if with_visited_left_side { - // TODO: Replace `ceil` wrapper with stable `div_cell` after - // https://github.com/rust-lang/rust/issues/88581 - let buffer_size = bit_util::ceil(merged_batch.num_rows(), 8); + let n_rows = merged_batch.num_rows(); + let buffer_size = n_rows.div_ceil(8); reservation.try_grow(buffer_size)?; metrics.build_mem_used.add(buffer_size); - let mut buffer = BooleanBufferBuilder::new(merged_batch.num_rows()); - buffer.append_n(merged_batch.num_rows(), false); + let mut buffer = BooleanBufferBuilder::new(n_rows); + buffer.append_n(n_rows, false); buffer } else { BooleanBufferBuilder::new(0) From 45926ab28d8a7962f2fec41bc23ff89eda83c0e4 Mon Sep 17 00:00:00 2001 From: Huaijin Date: Tue, 10 Dec 2024 04:26:19 +0800 Subject: [PATCH 13/26] Fix hash join with sort push down (#13560) * fix: join with sort push down * chore: insert some value * apply suggestion * recover handle_costom_pushdown change * apply suggestion * add more test * add partition --- .../src/physical_optimizer/sort_pushdown.rs | 101 +++++++++++ datafusion/sqllogictest/test_files/joins.slt | 171 +++++++++++++----- 2 files changed, 228 insertions(+), 44 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index d48c7118cb8e..6c761f674b3b 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -28,6 +28,7 @@ use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::tree_node::PlanContext; use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; +use arrow_schema::SchemaRef; use datafusion_common::tree_node::{ ConcreteTreeNode, Transformed, TreeNode, TreeNodeRecursion, @@ -38,6 +39,8 @@ use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::PhysicalSortRequirement; use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_plan::joins::utils::ColumnIndex; +use datafusion_physical_plan::joins::HashJoinExec; /// This is a "data class" we use within the [`EnforceSorting`] rule to push /// down [`SortExec`] in the plan. In some cases, we can reduce the total @@ -294,6 +297,8 @@ fn pushdown_requirement_to_children( .then(|| LexRequirement::new(parent_required.to_vec())); Ok(Some(vec![req])) } + } else if let Some(hash_join) = plan.as_any().downcast_ref::() { + handle_hash_join(hash_join, parent_required) } else { handle_custom_pushdown(plan, parent_required, maintains_input_order) } @@ -606,6 +611,102 @@ fn handle_custom_pushdown( } } +// For hash join we only maintain the input order for the right child +// for join type: Inner, Right, RightSemi, RightAnti +fn handle_hash_join( + plan: &HashJoinExec, + parent_required: &LexRequirement, +) -> Result>>> { + // If there's no requirement from the parent or the plan has no children + // or the join type is not Inner, Right, RightSemi, RightAnti, return early + if parent_required.is_empty() || !plan.maintains_input_order()[1] { + return Ok(None); + } + + // Collect all unique column indices used in the parent-required sorting expression + let all_indices: HashSet = parent_required + .iter() + .flat_map(|order| { + collect_columns(&order.expr) + .into_iter() + .map(|col| col.index()) + .collect::>() + }) + .collect(); + + let column_indices = build_join_column_index(plan); + let projected_indices: Vec<_> = if let Some(projection) = &plan.projection { + projection.iter().map(|&i| &column_indices[i]).collect() + } else { + column_indices.iter().collect() + }; + let len_of_left_fields = projected_indices + .iter() + .filter(|ci| ci.side == JoinSide::Left) + .count(); + + let all_from_right_child = all_indices.iter().all(|i| *i >= len_of_left_fields); + + // If all columns are from the right child, update the parent requirements + if all_from_right_child { + // Transform the parent-required expression for the child schema by adjusting columns + let updated_parent_req = parent_required + .iter() + .map(|req| { + let child_schema = plan.children()[1].schema(); + let updated_columns = Arc::clone(&req.expr) + .transform_up(|expr| { + if let Some(col) = expr.as_any().downcast_ref::() { + let index = projected_indices[col.index()].index; + Ok(Transformed::yes(Arc::new(Column::new( + child_schema.field(index).name(), + index, + )))) + } else { + Ok(Transformed::no(expr)) + } + })? + .data; + Ok(PhysicalSortRequirement::new(updated_columns, req.options)) + }) + .collect::>>()?; + + // Populating with the updated requirements for children that maintain order + Ok(Some(vec![ + None, + Some(LexRequirement::new(updated_parent_req)), + ])) + } else { + Ok(None) + } +} + +// this function is used to build the column index for the hash join +// push down sort requirements to the right child +fn build_join_column_index(plan: &HashJoinExec) -> Vec { + let map_fields = |schema: SchemaRef, side: JoinSide| { + schema + .fields() + .iter() + .enumerate() + .map(|(index, _)| ColumnIndex { index, side }) + .collect::>() + }; + + match plan.join_type() { + JoinType::Inner | JoinType::Right => { + map_fields(plan.left().schema(), JoinSide::Left) + .into_iter() + .chain(map_fields(plan.right().schema(), JoinSide::Right)) + .collect::>() + } + JoinType::RightSemi | JoinType::RightAnti => { + map_fields(plan.right().schema(), JoinSide::Right) + } + _ => unreachable!("unexpected join type: {}", plan.join_type()), + } +} + /// Define the Requirements Compatibility #[derive(Debug)] enum RequirementsCompatibility { diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index e636e93007a4..62f625119897 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -2864,13 +2864,13 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id I ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] -05)--------CoalesceBatchesExec: target_batch_size=2 -06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 -07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -08)--------------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] +04)------CoalesceBatchesExec: target_batch_size=2 +05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +07)------------MemoryExec: partitions=1, partition_sizes=[1] +08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 09)--------CoalesceBatchesExec: target_batch_size=2 10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -2905,13 +2905,13 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOI ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] -05)--------CoalesceBatchesExec: target_batch_size=2 -06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 -07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -08)--------------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] +04)------CoalesceBatchesExec: target_batch_size=2 +05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +07)------------MemoryExec: partitions=1, partition_sizes=[1] +08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 09)--------CoalesceBatchesExec: target_batch_size=2 10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -2967,10 +2967,10 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id I ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] -05)--------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] +04)------MemoryExec: partitions=1, partition_sizes=[1] +05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 07)----------MemoryExec: partitions=1, partition_sizes=[1] @@ -3003,10 +3003,10 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOI ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] -05)--------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] +04)------MemoryExec: partitions=1, partition_sizes=[1] +05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 07)----------MemoryExec: partitions=1, partition_sizes=[1] @@ -3061,13 +3061,13 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHER ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0 -05)--------CoalesceBatchesExec: target_batch_size=2 -06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 -07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -08)--------------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0 +04)------CoalesceBatchesExec: target_batch_size=2 +05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +07)------------MemoryExec: partitions=1, partition_sizes=[1] +08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 09)--------CoalesceBatchesExec: target_batch_size=2 10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -3083,13 +3083,13 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t2 t2 RIGH ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1 -05)--------CoalesceBatchesExec: target_batch_size=2 -06)----------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 -07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -08)--------------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1 +04)------CoalesceBatchesExec: target_batch_size=2 +05)--------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +07)------------MemoryExec: partitions=1, partition_sizes=[1] +08)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 09)--------CoalesceBatchesExec: target_batch_size=2 10)----------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 11)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -3143,10 +3143,10 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHER ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0 -05)--------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0 +04)------MemoryExec: partitions=1, partition_sizes=[1] +05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 07)----------MemoryExec: partitions=1, partition_sizes=[1] @@ -3160,10 +3160,10 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t2 t2 RIGH ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1 -05)--------MemoryExec: partitions=1, partition_sizes=[1] +02)--CoalesceBatchesExec: target_batch_size=2 +03)----HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1 +04)------MemoryExec: partitions=1, partition_sizes=[1] +05)------SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 07)----------MemoryExec: partitions=1, partition_sizes=[1] @@ -4313,3 +4313,86 @@ physical_plan 04)------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(binary_col@0, binary_col@0)] 05)--------MemoryExec: partitions=1, partition_sizes=[1] 06)--------MemoryExec: partitions=1, partition_sizes=[1] + +# Test hash join sort push down +# Issue: https://github.com/apache/datafusion/issues/13559 +statement ok +CREATE TABLE test(a INT, b INT, c INT) + +statement ok +insert into test values (1,2,3), (4,5,6), (null, 7, 8), (8, null, 9), (9, 10, null) + +statement ok +set datafusion.execution.target_partitions = 2; + +query TT +explain select * from test where a in (select a from test where b > 3) order by c desc nulls first; +---- +logical_plan +01)Sort: test.c DESC NULLS FIRST +02)--LeftSemi Join: test.a = __correlated_sq_1.a +03)----TableScan: test projection=[a, b, c] +04)----SubqueryAlias: __correlated_sq_1 +05)------Projection: test.a +06)--------Filter: test.b > Int32(3) +07)----------TableScan: test projection=[a, b] +physical_plan +01)SortPreservingMergeExec: [c@2 DESC] +02)--CoalesceBatchesExec: target_batch_size=3 +03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(a@0, a@0)] +04)------CoalesceBatchesExec: target_batch_size=3 +05)--------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 +06)----------CoalesceBatchesExec: target_batch_size=3 +07)------------FilterExec: b@1 > 3, projection=[a@0] +08)--------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +09)----------------MemoryExec: partitions=1, partition_sizes=[1] +10)------SortExec: expr=[c@2 DESC], preserve_partitioning=[true] +11)--------CoalesceBatchesExec: target_batch_size=3 +12)----------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 +13)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +14)--------------MemoryExec: partitions=1, partition_sizes=[1] + +query TT +explain select * from test where a in (select a from test where b > 3) order by c desc nulls last; +---- +logical_plan +01)Sort: test.c DESC NULLS LAST +02)--LeftSemi Join: test.a = __correlated_sq_1.a +03)----TableScan: test projection=[a, b, c] +04)----SubqueryAlias: __correlated_sq_1 +05)------Projection: test.a +06)--------Filter: test.b > Int32(3) +07)----------TableScan: test projection=[a, b] +physical_plan +01)SortPreservingMergeExec: [c@2 DESC NULLS LAST] +02)--CoalesceBatchesExec: target_batch_size=3 +03)----HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(a@0, a@0)] +04)------CoalesceBatchesExec: target_batch_size=3 +05)--------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 +06)----------CoalesceBatchesExec: target_batch_size=3 +07)------------FilterExec: b@1 > 3, projection=[a@0] +08)--------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +09)----------------MemoryExec: partitions=1, partition_sizes=[1] +10)------SortExec: expr=[c@2 DESC NULLS LAST], preserve_partitioning=[true] +11)--------CoalesceBatchesExec: target_batch_size=3 +12)----------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 +13)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +14)--------------MemoryExec: partitions=1, partition_sizes=[1] + +query III +select * from test where a in (select a from test where b > 3) order by c desc nulls first; +---- +9 10 NULL +4 5 6 + +query III +select * from test where a in (select a from test where b > 3) order by c desc nulls last; +---- +4 5 6 +9 10 NULL + +statement ok +DROP TABLE test + +statement ok +set datafusion.execution.target_partitions = 1; From 16d2ab133a2c3e8e0aaf9fa14599178f84374c7f Mon Sep 17 00:00:00 2001 From: Zhang Li Date: Tue, 10 Dec 2024 05:49:19 +0800 Subject: [PATCH 14/26] Improve substr() performance by avoiding using owned string (#13688) Co-authored-by: zhangli20 --- datafusion/functions/src/unicode/substr.rs | 77 +++++++++++----------- 1 file changed, 40 insertions(+), 37 deletions(-) diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 141984cf2674..687f77dbef5b 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -21,8 +21,8 @@ use std::sync::{Arc, OnceLock}; use crate::strings::{make_and_append_view, StringArrayType}; use crate::utils::{make_scalar_function, utf8_to_str_type}; use arrow::array::{ - Array, ArrayIter, ArrayRef, AsArray, GenericStringArray, Int64Array, OffsetSizeTrait, - StringViewArray, + Array, ArrayIter, ArrayRef, AsArray, GenericStringBuilder, Int64Array, + OffsetSizeTrait, StringViewArray, }; use arrow::datatypes::DataType; use arrow_buffer::{NullBufferBuilder, ScalarBuffer}; @@ -448,10 +448,9 @@ where match args.len() { 1 => { let iter = ArrayIter::new(string_array); - - let result = iter - .zip(start_array.iter()) - .map(|(string, start)| match (string, start) { + let mut result_builder = GenericStringBuilder::::new(); + for (string, start) in iter.zip(start_array.iter()) { + match (string, start) { (Some(string), Some(start)) => { let (start, end) = get_true_start_end( string, @@ -460,47 +459,51 @@ where enable_ascii_fast_path, ); // start, end is byte-based let substr = &string[start..end]; - Some(substr.to_string()) + result_builder.append_value(substr); } - _ => None, - }) - .collect::>(); - Ok(Arc::new(result) as ArrayRef) + _ => { + result_builder.append_null(); + } + } + } + Ok(Arc::new(result_builder.finish()) as ArrayRef) } 2 => { let iter = ArrayIter::new(string_array); let count_array = count_array_opt.unwrap(); + let mut result_builder = GenericStringBuilder::::new(); - let result = iter - .zip(start_array.iter()) - .zip(count_array.iter()) - .map(|((string, start), count)| { - match (string, start, count) { - (Some(string), Some(start), Some(count)) => { - if count < 0 { - exec_err!( + for ((string, start), count) in + iter.zip(start_array.iter()).zip(count_array.iter()) + { + match (string, start, count) { + (Some(string), Some(start), Some(count)) => { + if count < 0 { + return exec_err!( "negative substring length not allowed: substr(, {start}, {count})" - ) - } else { - if start == i64::MIN { - return exec_err!("negative overflow when calculating skip value"); - } - let (start, end) = get_true_start_end( - string, - start, - Some(count as u64), - enable_ascii_fast_path, - ); // start, end is byte-based - let substr = &string[start..end]; - Ok(Some(substr.to_string())) + ); + } else { + if start == i64::MIN { + return exec_err!( + "negative overflow when calculating skip value" + ); } + let (start, end) = get_true_start_end( + string, + start, + Some(count as u64), + enable_ascii_fast_path, + ); // start, end is byte-based + let substr = &string[start..end]; + result_builder.append_value(substr); } - _ => Ok(None), } - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) + _ => { + result_builder.append_null(); + } + } + } + Ok(Arc::new(result_builder.finish()) as ArrayRef) } other => { exec_err!("substr was called with {other} arguments. It requires 2 or 3.") From d8c9cfb4be3106eaf506741e2897640fad028d26 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 9 Dec 2024 14:49:58 -0700 Subject: [PATCH 15/26] reinstate down_cast_any_ref (#13705) --- .../physical-expr-common/src/physical_expr.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 93bdcdef8ea0..c2e892d63da0 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -214,6 +214,21 @@ pub fn with_new_children_if_necessary( } } +#[deprecated(since = "44.0.0")] +pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { + if any.is::>() { + any.downcast_ref::>() + .unwrap() + .as_any() + } else if any.is::>() { + any.downcast_ref::>() + .unwrap() + .as_any() + } else { + any + } +} + /// Returns [`Display`] able a list of [`PhysicalExpr`] /// /// Example output: `[a + 1, b]` From f8c0efe1cf41f48c25c07cdd135774168d63744c Mon Sep 17 00:00:00 2001 From: Tai Le Manh Date: Tue, 10 Dec 2024 09:33:25 +0700 Subject: [PATCH 16/26] Optimize performance of `character_length` function (#13696) * Optimize performance of function Signed-off-by: Tai Le Manh * Add pre-check array is null * Fix clippy warnings --------- Signed-off-by: Tai Le Manh --- .../functions/src/unicode/character_length.rs | 57 +++++++++++++------ 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/datafusion/functions/src/unicode/character_length.rs b/datafusion/functions/src/unicode/character_length.rs index 822bdca9aca8..ad51a8ef72fb 100644 --- a/datafusion/functions/src/unicode/character_length.rs +++ b/datafusion/functions/src/unicode/character_length.rs @@ -18,7 +18,7 @@ use crate::strings::StringArrayType; use crate::utils::{make_scalar_function, utf8_to_int_type}; use arrow::array::{ - Array, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, PrimitiveArray, + Array, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, PrimitiveBuilder, }; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; use datafusion_common::Result; @@ -136,31 +136,52 @@ fn character_length(args: &[ArrayRef]) -> Result { } } -fn character_length_general<'a, T: ArrowPrimitiveType, V: StringArrayType<'a>>( - array: V, -) -> Result +fn character_length_general<'a, T, V>(array: V) -> Result where + T: ArrowPrimitiveType, T::Native: OffsetSizeTrait, + V: StringArrayType<'a>, { + let mut builder = PrimitiveBuilder::::with_capacity(array.len()); + // String characters are variable length encoded in UTF-8, counting the // number of chars requires expensive decoding, however checking if the // string is ASCII only is relatively cheap. // If strings are ASCII only, count bytes instead. let is_array_ascii_only = array.is_ascii(); - let iter = array.iter(); - let result = iter - .map(|string| { - string.map(|string: &str| { - if is_array_ascii_only { - T::Native::usize_as(string.len()) - } else { - T::Native::usize_as(string.chars().count()) - } - }) - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) + if array.null_count() == 0 { + if is_array_ascii_only { + for i in 0..array.len() { + let value = array.value(i); + builder.append_value(T::Native::usize_as(value.len())); + } + } else { + for i in 0..array.len() { + let value = array.value(i); + builder.append_value(T::Native::usize_as(value.chars().count())); + } + } + } else if is_array_ascii_only { + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null(); + } else { + let value = array.value(i); + builder.append_value(T::Native::usize_as(value.len())); + } + } + } else { + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null(); + } else { + let value = array.value(i); + builder.append_value(T::Native::usize_as(value.chars().count())); + } + } + } + + Ok(Arc::new(builder.finish()) as ArrayRef) } #[cfg(test)] From 2d8bd422ea7e8f8714463022b4ed9cf9efb36564 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 10 Dec 2024 02:21:58 -0800 Subject: [PATCH 17/26] Update prost-build requirement from =0.13.3 to =0.13.4 (#13698) Updates the requirements on [prost-build](https://github.com/tokio-rs/prost) to permit the latest version. - [Release notes](https://github.com/tokio-rs/prost/releases) - [Changelog](https://github.com/tokio-rs/prost/blob/master/CHANGELOG.md) - [Commits](https://github.com/tokio-rs/prost/compare/v0.13.3...v0.13.4) --- updated-dependencies: - dependency-name: prost-build dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- datafusion/proto-common/gen/Cargo.toml | 2 +- datafusion/proto/gen/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/proto-common/gen/Cargo.toml b/datafusion/proto-common/gen/Cargo.toml index da5bc6029ff9..21fc9eccb40c 100644 --- a/datafusion/proto-common/gen/Cargo.toml +++ b/datafusion/proto-common/gen/Cargo.toml @@ -35,4 +35,4 @@ workspace = true [dependencies] # Pin these dependencies so that the generated output is deterministic pbjson-build = "=0.7.0" -prost-build = "=0.13.3" +prost-build = "=0.13.4" diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml index 297406becada..dda72d20a159 100644 --- a/datafusion/proto/gen/Cargo.toml +++ b/datafusion/proto/gen/Cargo.toml @@ -35,4 +35,4 @@ workspace = true [dependencies] # Pin these dependencies so that the generated output is deterministic pbjson-build = "=0.7.0" -prost-build = "=0.13.3" +prost-build = "=0.13.4" From 5dc6e42aa26e7023d9dd72fde3a4a6488c890d23 Mon Sep 17 00:00:00 2001 From: Oleks V Date: Tue, 10 Dec 2024 12:18:39 -0800 Subject: [PATCH 18/26] Minor: Output elapsed time for sql logic test (#13718) * Minor: Output elapsed time for sql logic test --- datafusion/sqllogictest/bin/sqllogictests.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs b/datafusion/sqllogictest/bin/sqllogictests.rs index 12c0e27ea911..176bd3229125 100644 --- a/datafusion/sqllogictest/bin/sqllogictests.rs +++ b/datafusion/sqllogictest/bin/sqllogictests.rs @@ -15,10 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::ffi::OsStr; -use std::fs; -use std::path::{Path, PathBuf}; - use clap::Parser; use datafusion_common::utils::get_available_parallelism; use datafusion_sqllogictest::{DataFusion, TestContext}; @@ -26,6 +22,9 @@ use futures::stream::StreamExt; use itertools::Itertools; use log::info; use sqllogictest::strict_column_validator; +use std::ffi::OsStr; +use std::fs; +use std::path::{Path, PathBuf}; use datafusion_common::{exec_datafusion_err, exec_err, DataFusionError, Result}; use datafusion_common_runtime::SpawnedTask; @@ -100,7 +99,8 @@ async fn run_tests() -> Result<()> { let errors: Vec<_> = futures::stream::iter(read_test_files(&options)?) .map(|test_file| { SpawnedTask::spawn(async move { - println!("Running {:?}", test_file.relative_path); + let file_path = test_file.relative_path.clone(); + let start = datafusion::common::instant::Instant::now(); if options.complete { run_complete_file(test_file).await?; } else if options.postgres_runner { @@ -108,6 +108,7 @@ async fn run_tests() -> Result<()> { } else { run_test_file(test_file).await?; } + println!("Executed {:?}. Took {:?}", file_path, start.elapsed()); Ok(()) as Result<()> }) .join() From 4fb9d2a0cc08b6177cb99d3a2f0b98e0a49ddb60 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Wed, 11 Dec 2024 09:14:44 +0800 Subject: [PATCH 19/26] refactor: simplify the `make_udf_function` macro (#13712) --- datafusion/functions/src/core/mod.rs | 22 +++++----- datafusion/functions/src/crypto/mod.rs | 12 +++--- datafusion/functions/src/datetime/mod.rs | 54 ++++++++---------------- datafusion/functions/src/encoding/mod.rs | 4 +- datafusion/functions/src/macros.rs | 25 ++++++----- datafusion/functions/src/math/mod.rs | 53 +++++++---------------- datafusion/functions/src/regex/mod.rs | 12 ++---- datafusion/functions/src/string/mod.rs | 44 +++++++++---------- datafusion/functions/src/unicode/mod.rs | 28 ++++++------ 9 files changed, 101 insertions(+), 153 deletions(-) diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 24d26c539539..bd8305cd56d8 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -35,17 +35,17 @@ pub mod r#struct; pub mod version; // create UDFs -make_udf_function!(arrow_cast::ArrowCastFunc, ARROW_CAST, arrow_cast); -make_udf_function!(nullif::NullIfFunc, NULLIF, nullif); -make_udf_function!(nvl::NVLFunc, NVL, nvl); -make_udf_function!(nvl2::NVL2Func, NVL2, nvl2); -make_udf_function!(arrowtypeof::ArrowTypeOfFunc, ARROWTYPEOF, arrow_typeof); -make_udf_function!(r#struct::StructFunc, STRUCT, r#struct); -make_udf_function!(named_struct::NamedStructFunc, NAMED_STRUCT, named_struct); -make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field); -make_udf_function!(coalesce::CoalesceFunc, COALESCE, coalesce); -make_udf_function!(greatest::GreatestFunc, GREATEST, greatest); -make_udf_function!(version::VersionFunc, VERSION, version); +make_udf_function!(arrow_cast::ArrowCastFunc, arrow_cast); +make_udf_function!(nullif::NullIfFunc, nullif); +make_udf_function!(nvl::NVLFunc, nvl); +make_udf_function!(nvl2::NVL2Func, nvl2); +make_udf_function!(arrowtypeof::ArrowTypeOfFunc, arrow_typeof); +make_udf_function!(r#struct::StructFunc, r#struct); +make_udf_function!(named_struct::NamedStructFunc, named_struct); +make_udf_function!(getfield::GetFieldFunc, get_field); +make_udf_function!(coalesce::CoalesceFunc, coalesce); +make_udf_function!(greatest::GreatestFunc, greatest); +make_udf_function!(version::VersionFunc, version); pub mod expr_fn { use datafusion_expr::{Expr, Literal}; diff --git a/datafusion/functions/src/crypto/mod.rs b/datafusion/functions/src/crypto/mod.rs index 46177fc22b60..62ea3c2e2737 100644 --- a/datafusion/functions/src/crypto/mod.rs +++ b/datafusion/functions/src/crypto/mod.rs @@ -27,12 +27,12 @@ pub mod sha224; pub mod sha256; pub mod sha384; pub mod sha512; -make_udf_function!(digest::DigestFunc, DIGEST, digest); -make_udf_function!(md5::Md5Func, MD5, md5); -make_udf_function!(sha224::SHA224Func, SHA224, sha224); -make_udf_function!(sha256::SHA256Func, SHA256, sha256); -make_udf_function!(sha384::SHA384Func, SHA384, sha384); -make_udf_function!(sha512::SHA512Func, SHA512, sha512); +make_udf_function!(digest::DigestFunc, digest); +make_udf_function!(md5::Md5Func, md5); +make_udf_function!(sha224::SHA224Func, sha224); +make_udf_function!(sha256::SHA256Func, sha256); +make_udf_function!(sha384::SHA384Func, sha384); +make_udf_function!(sha512::SHA512Func, sha512); pub mod expr_fn { export_functions!(( diff --git a/datafusion/functions/src/datetime/mod.rs b/datafusion/functions/src/datetime/mod.rs index db4e365267dd..96ca63010ee4 100644 --- a/datafusion/functions/src/datetime/mod.rs +++ b/datafusion/functions/src/datetime/mod.rs @@ -37,43 +37,23 @@ pub mod to_timestamp; pub mod to_unixtime; // create UDFs -make_udf_function!(current_date::CurrentDateFunc, CURRENT_DATE, current_date); -make_udf_function!(current_time::CurrentTimeFunc, CURRENT_TIME, current_time); -make_udf_function!(date_bin::DateBinFunc, DATE_BIN, date_bin); -make_udf_function!(date_part::DatePartFunc, DATE_PART, date_part); -make_udf_function!(date_trunc::DateTruncFunc, DATE_TRUNC, date_trunc); -make_udf_function!(make_date::MakeDateFunc, MAKE_DATE, make_date); -make_udf_function!( - from_unixtime::FromUnixtimeFunc, - FROM_UNIXTIME, - from_unixtime -); -make_udf_function!(now::NowFunc, NOW, now); -make_udf_function!(to_char::ToCharFunc, TO_CHAR, to_char); -make_udf_function!(to_date::ToDateFunc, TO_DATE, to_date); -make_udf_function!(to_local_time::ToLocalTimeFunc, TO_LOCAL_TIME, to_local_time); -make_udf_function!(to_unixtime::ToUnixtimeFunc, TO_UNIXTIME, to_unixtime); -make_udf_function!(to_timestamp::ToTimestampFunc, TO_TIMESTAMP, to_timestamp); -make_udf_function!( - to_timestamp::ToTimestampSecondsFunc, - TO_TIMESTAMP_SECONDS, - to_timestamp_seconds -); -make_udf_function!( - to_timestamp::ToTimestampMillisFunc, - TO_TIMESTAMP_MILLIS, - to_timestamp_millis -); -make_udf_function!( - to_timestamp::ToTimestampMicrosFunc, - TO_TIMESTAMP_MICROS, - to_timestamp_micros -); -make_udf_function!( - to_timestamp::ToTimestampNanosFunc, - TO_TIMESTAMP_NANOS, - to_timestamp_nanos -); +make_udf_function!(current_date::CurrentDateFunc, current_date); +make_udf_function!(current_time::CurrentTimeFunc, current_time); +make_udf_function!(date_bin::DateBinFunc, date_bin); +make_udf_function!(date_part::DatePartFunc, date_part); +make_udf_function!(date_trunc::DateTruncFunc, date_trunc); +make_udf_function!(make_date::MakeDateFunc, make_date); +make_udf_function!(from_unixtime::FromUnixtimeFunc, from_unixtime); +make_udf_function!(now::NowFunc, now); +make_udf_function!(to_char::ToCharFunc, to_char); +make_udf_function!(to_date::ToDateFunc, to_date); +make_udf_function!(to_local_time::ToLocalTimeFunc, to_local_time); +make_udf_function!(to_unixtime::ToUnixtimeFunc, to_unixtime); +make_udf_function!(to_timestamp::ToTimestampFunc, to_timestamp); +make_udf_function!(to_timestamp::ToTimestampSecondsFunc, to_timestamp_seconds); +make_udf_function!(to_timestamp::ToTimestampMillisFunc, to_timestamp_millis); +make_udf_function!(to_timestamp::ToTimestampMicrosFunc, to_timestamp_micros); +make_udf_function!(to_timestamp::ToTimestampNanosFunc, to_timestamp_nanos); // we cannot currently use the export_functions macro since it doesn't handle // functions with varargs currently diff --git a/datafusion/functions/src/encoding/mod.rs b/datafusion/functions/src/encoding/mod.rs index 48171370ad58..b0ddbd368a6b 100644 --- a/datafusion/functions/src/encoding/mod.rs +++ b/datafusion/functions/src/encoding/mod.rs @@ -21,8 +21,8 @@ use std::sync::Arc; pub mod inner; // create `encode` and `decode` UDFs -make_udf_function!(inner::EncodeFunc, ENCODE, encode); -make_udf_function!(inner::DecodeFunc, DECODE, decode); +make_udf_function!(inner::EncodeFunc, encode); +make_udf_function!(inner::DecodeFunc, decode); // Export the functions out of this package, both as expr_fn as well as a list of functions pub mod expr_fn { diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index bedec9bb2e6f..82308601490c 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -65,24 +65,23 @@ macro_rules! export_functions { }; } -/// Creates a singleton `ScalarUDF` of the `$UDF` function named `$GNAME` and a -/// function named `$NAME` which returns that singleton. +/// Creates a singleton `ScalarUDF` of the `$UDF` function and a function +/// named `$NAME` which returns that singleton. /// /// This is used to ensure creating the list of `ScalarUDF` only happens once. macro_rules! make_udf_function { - ($UDF:ty, $GNAME:ident, $NAME:ident) => { - #[doc = "Return a [`ScalarUDF`](datafusion_expr::ScalarUDF) implementation "] - #[doc = stringify!($UDF)] + ($UDF:ty, $NAME:ident) => { + #[doc = concat!("Return a [`ScalarUDF`](datafusion_expr::ScalarUDF) implementation of ", stringify!($NAME))] pub fn $NAME() -> std::sync::Arc { // Singleton instance of the function - static $GNAME: std::sync::LazyLock< + static INSTANCE: std::sync::LazyLock< std::sync::Arc, > = std::sync::LazyLock::new(|| { std::sync::Arc::new(datafusion_expr::ScalarUDF::new_from_impl( <$UDF>::new(), )) }); - std::sync::Arc::clone(&$GNAME) + std::sync::Arc::clone(&INSTANCE) } }; } @@ -134,13 +133,13 @@ macro_rules! downcast_arg { /// applies a unary floating function to the argument, and returns a value of the same type. /// /// $UDF: the name of the UDF struct that implements `ScalarUDFImpl` -/// $GNAME: a singleton instance of the UDF /// $NAME: the name of the function /// $UNARY_FUNC: the unary function to apply to the argument /// $OUTPUT_ORDERING: the output ordering calculation method of the function +/// $GET_DOC: the function to get the documentation of the UDF macro_rules! make_math_unary_udf { - ($UDF:ident, $GNAME:ident, $NAME:ident, $UNARY_FUNC:ident, $OUTPUT_ORDERING:expr, $EVALUATE_BOUNDS:expr, $GET_DOC:expr) => { - make_udf_function!($NAME::$UDF, $GNAME, $NAME); + ($UDF:ident, $NAME:ident, $UNARY_FUNC:ident, $OUTPUT_ORDERING:expr, $EVALUATE_BOUNDS:expr, $GET_DOC:expr) => { + make_udf_function!($NAME::$UDF, $NAME); mod $NAME { use std::any::Any; @@ -248,13 +247,13 @@ macro_rules! make_math_unary_udf { /// applies a binary floating function to the argument, and returns a value of the same type. /// /// $UDF: the name of the UDF struct that implements `ScalarUDFImpl` -/// $GNAME: a singleton instance of the UDF /// $NAME: the name of the function /// $BINARY_FUNC: the binary function to apply to the argument /// $OUTPUT_ORDERING: the output ordering calculation method of the function +/// $GET_DOC: the function to get the documentation of the UDF macro_rules! make_math_binary_udf { - ($UDF:ident, $GNAME:ident, $NAME:ident, $BINARY_FUNC:ident, $OUTPUT_ORDERING:expr, $GET_DOC:expr) => { - make_udf_function!($NAME::$UDF, $GNAME, $NAME); + ($UDF:ident, $NAME:ident, $BINARY_FUNC:ident, $OUTPUT_ORDERING:expr, $GET_DOC:expr) => { + make_udf_function!($NAME::$UDF, $NAME); mod $NAME { use std::any::Any; diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 1452bfdee5a0..4eb337a30110 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -40,10 +40,9 @@ pub mod signum; pub mod trunc; // Create UDFs -make_udf_function!(abs::AbsFunc, ABS, abs); +make_udf_function!(abs::AbsFunc, abs); make_math_unary_udf!( AcosFunc, - ACOS, acos, acos, super::acos_order, @@ -52,7 +51,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( AcoshFunc, - ACOSH, acosh, acosh, super::acosh_order, @@ -61,7 +59,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( AsinFunc, - ASIN, asin, asin, super::asin_order, @@ -70,7 +67,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( AsinhFunc, - ASINH, asinh, asinh, super::asinh_order, @@ -79,7 +75,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( AtanFunc, - ATAN, atan, atan, super::atan_order, @@ -88,7 +83,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( AtanhFunc, - ATANH, atanh, atanh, super::atanh_order, @@ -97,7 +91,6 @@ make_math_unary_udf!( ); make_math_binary_udf!( Atan2, - ATAN2, atan2, atan2, super::atan2_order, @@ -105,7 +98,6 @@ make_math_binary_udf!( ); make_math_unary_udf!( CbrtFunc, - CBRT, cbrt, cbrt, super::cbrt_order, @@ -114,7 +106,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( CeilFunc, - CEIL, ceil, ceil, super::ceil_order, @@ -123,7 +114,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( CosFunc, - COS, cos, cos, super::cos_order, @@ -132,17 +122,15 @@ make_math_unary_udf!( ); make_math_unary_udf!( CoshFunc, - COSH, cosh, cosh, super::cosh_order, super::bounds::cosh_bounds, super::get_cosh_doc ); -make_udf_function!(cot::CotFunc, COT, cot); +make_udf_function!(cot::CotFunc, cot); make_math_unary_udf!( DegreesFunc, - DEGREES, degrees, to_degrees, super::degrees_order, @@ -151,31 +139,28 @@ make_math_unary_udf!( ); make_math_unary_udf!( ExpFunc, - EXP, exp, exp, super::exp_order, super::bounds::exp_bounds, super::get_exp_doc ); -make_udf_function!(factorial::FactorialFunc, FACTORIAL, factorial); +make_udf_function!(factorial::FactorialFunc, factorial); make_math_unary_udf!( FloorFunc, - FLOOR, floor, floor, super::floor_order, super::bounds::unbounded_bounds, super::get_floor_doc ); -make_udf_function!(log::LogFunc, LOG, log); -make_udf_function!(gcd::GcdFunc, GCD, gcd); -make_udf_function!(nans::IsNanFunc, ISNAN, isnan); -make_udf_function!(iszero::IsZeroFunc, ISZERO, iszero); -make_udf_function!(lcm::LcmFunc, LCM, lcm); +make_udf_function!(log::LogFunc, log); +make_udf_function!(gcd::GcdFunc, gcd); +make_udf_function!(nans::IsNanFunc, isnan); +make_udf_function!(iszero::IsZeroFunc, iszero); +make_udf_function!(lcm::LcmFunc, lcm); make_math_unary_udf!( LnFunc, - LN, ln, ln, super::ln_order, @@ -184,7 +169,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( Log2Func, - LOG2, log2, log2, super::log2_order, @@ -193,31 +177,28 @@ make_math_unary_udf!( ); make_math_unary_udf!( Log10Func, - LOG10, log10, log10, super::log10_order, super::bounds::unbounded_bounds, super::get_log10_doc ); -make_udf_function!(nanvl::NanvlFunc, NANVL, nanvl); -make_udf_function!(pi::PiFunc, PI, pi); -make_udf_function!(power::PowerFunc, POWER, power); +make_udf_function!(nanvl::NanvlFunc, nanvl); +make_udf_function!(pi::PiFunc, pi); +make_udf_function!(power::PowerFunc, power); make_math_unary_udf!( RadiansFunc, - RADIANS, radians, to_radians, super::radians_order, super::bounds::radians_bounds, super::get_radians_doc ); -make_udf_function!(random::RandomFunc, RANDOM, random); -make_udf_function!(round::RoundFunc, ROUND, round); -make_udf_function!(signum::SignumFunc, SIGNUM, signum); +make_udf_function!(random::RandomFunc, random); +make_udf_function!(round::RoundFunc, round); +make_udf_function!(signum::SignumFunc, signum); make_math_unary_udf!( SinFunc, - SIN, sin, sin, super::sin_order, @@ -226,7 +207,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( SinhFunc, - SINH, sinh, sinh, super::sinh_order, @@ -235,7 +215,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( SqrtFunc, - SQRT, sqrt, sqrt, super::sqrt_order, @@ -244,7 +223,6 @@ make_math_unary_udf!( ); make_math_unary_udf!( TanFunc, - TAN, tan, tan, super::tan_order, @@ -253,14 +231,13 @@ make_math_unary_udf!( ); make_math_unary_udf!( TanhFunc, - TANH, tanh, tanh, super::tanh_order, super::bounds::tanh_bounds, super::get_tanh_doc ); -make_udf_function!(trunc::TruncFunc, TRUNC, trunc); +make_udf_function!(trunc::TruncFunc, trunc); pub mod expr_fn { export_functions!( diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index 803f51e915a9..13fbc049af58 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -25,14 +25,10 @@ pub mod regexpmatch; pub mod regexpreplace; // create UDFs -make_udf_function!(regexpcount::RegexpCountFunc, REGEXP_COUNT, regexp_count); -make_udf_function!(regexpmatch::RegexpMatchFunc, REGEXP_MATCH, regexp_match); -make_udf_function!(regexplike::RegexpLikeFunc, REGEXP_LIKE, regexp_like); -make_udf_function!( - regexpreplace::RegexpReplaceFunc, - REGEXP_REPLACE, - regexp_replace -); +make_udf_function!(regexpcount::RegexpCountFunc, regexp_count); +make_udf_function!(regexpmatch::RegexpMatchFunc, regexp_match); +make_udf_function!(regexplike::RegexpLikeFunc, regexp_like); +make_udf_function!(regexpreplace::RegexpReplaceFunc, regexp_replace); pub mod expr_fn { use datafusion_expr::Expr; diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index 622802f0142b..f156f070d960 100644 --- a/datafusion/functions/src/string/mod.rs +++ b/datafusion/functions/src/string/mod.rs @@ -45,28 +45,28 @@ pub mod to_hex; pub mod upper; pub mod uuid; // create UDFs -make_udf_function!(ascii::AsciiFunc, ASCII, ascii); -make_udf_function!(bit_length::BitLengthFunc, BIT_LENGTH, bit_length); -make_udf_function!(btrim::BTrimFunc, BTRIM, btrim); -make_udf_function!(chr::ChrFunc, CHR, chr); -make_udf_function!(concat::ConcatFunc, CONCAT, concat); -make_udf_function!(concat_ws::ConcatWsFunc, CONCAT_WS, concat_ws); -make_udf_function!(ends_with::EndsWithFunc, ENDS_WITH, ends_with); -make_udf_function!(initcap::InitcapFunc, INITCAP, initcap); -make_udf_function!(levenshtein::LevenshteinFunc, LEVENSHTEIN, levenshtein); -make_udf_function!(ltrim::LtrimFunc, LTRIM, ltrim); -make_udf_function!(lower::LowerFunc, LOWER, lower); -make_udf_function!(octet_length::OctetLengthFunc, OCTET_LENGTH, octet_length); -make_udf_function!(overlay::OverlayFunc, OVERLAY, overlay); -make_udf_function!(repeat::RepeatFunc, REPEAT, repeat); -make_udf_function!(replace::ReplaceFunc, REPLACE, replace); -make_udf_function!(rtrim::RtrimFunc, RTRIM, rtrim); -make_udf_function!(starts_with::StartsWithFunc, STARTS_WITH, starts_with); -make_udf_function!(split_part::SplitPartFunc, SPLIT_PART, split_part); -make_udf_function!(to_hex::ToHexFunc, TO_HEX, to_hex); -make_udf_function!(upper::UpperFunc, UPPER, upper); -make_udf_function!(uuid::UuidFunc, UUID, uuid); -make_udf_function!(contains::ContainsFunc, CONTAINS, contains); +make_udf_function!(ascii::AsciiFunc, ascii); +make_udf_function!(bit_length::BitLengthFunc, bit_length); +make_udf_function!(btrim::BTrimFunc, btrim); +make_udf_function!(chr::ChrFunc, chr); +make_udf_function!(concat::ConcatFunc, concat); +make_udf_function!(concat_ws::ConcatWsFunc, concat_ws); +make_udf_function!(ends_with::EndsWithFunc, ends_with); +make_udf_function!(initcap::InitcapFunc, initcap); +make_udf_function!(levenshtein::LevenshteinFunc, levenshtein); +make_udf_function!(ltrim::LtrimFunc, ltrim); +make_udf_function!(lower::LowerFunc, lower); +make_udf_function!(octet_length::OctetLengthFunc, octet_length); +make_udf_function!(overlay::OverlayFunc, overlay); +make_udf_function!(repeat::RepeatFunc, repeat); +make_udf_function!(replace::ReplaceFunc, replace); +make_udf_function!(rtrim::RtrimFunc, rtrim); +make_udf_function!(starts_with::StartsWithFunc, starts_with); +make_udf_function!(split_part::SplitPartFunc, split_part); +make_udf_function!(to_hex::ToHexFunc, to_hex); +make_udf_function!(upper::UpperFunc, upper); +make_udf_function!(uuid::UuidFunc, uuid); +make_udf_function!(contains::ContainsFunc, contains); pub mod expr_fn { use datafusion_expr::Expr; diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs index 40915bc9efde..f31ece9196d8 100644 --- a/datafusion/functions/src/unicode/mod.rs +++ b/datafusion/functions/src/unicode/mod.rs @@ -34,22 +34,18 @@ pub mod substrindex; pub mod translate; // create UDFs -make_udf_function!( - character_length::CharacterLengthFunc, - CHARACTER_LENGTH, - character_length -); -make_udf_function!(find_in_set::FindInSetFunc, FIND_IN_SET, find_in_set); -make_udf_function!(left::LeftFunc, LEFT, left); -make_udf_function!(lpad::LPadFunc, LPAD, lpad); -make_udf_function!(right::RightFunc, RIGHT, right); -make_udf_function!(reverse::ReverseFunc, REVERSE, reverse); -make_udf_function!(rpad::RPadFunc, RPAD, rpad); -make_udf_function!(strpos::StrposFunc, STRPOS, strpos); -make_udf_function!(substr::SubstrFunc, SUBSTR, substr); -make_udf_function!(substr::SubstrFunc, SUBSTRING, substring); -make_udf_function!(substrindex::SubstrIndexFunc, SUBSTR_INDEX, substr_index); -make_udf_function!(translate::TranslateFunc, TRANSLATE, translate); +make_udf_function!(character_length::CharacterLengthFunc, character_length); +make_udf_function!(find_in_set::FindInSetFunc, find_in_set); +make_udf_function!(left::LeftFunc, left); +make_udf_function!(lpad::LPadFunc, lpad); +make_udf_function!(right::RightFunc, right); +make_udf_function!(reverse::ReverseFunc, reverse); +make_udf_function!(rpad::RPadFunc, rpad); +make_udf_function!(strpos::StrposFunc, strpos); +make_udf_function!(substr::SubstrFunc, substr); +make_udf_function!(substr::SubstrFunc, substring); +make_udf_function!(substrindex::SubstrIndexFunc, substr_index); +make_udf_function!(translate::TranslateFunc, translate); pub mod expr_fn { use datafusion_expr::Expr; From fa0440b3193767f9214cdceeb08a1016652e768d Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Wed, 11 Dec 2024 20:02:10 +0800 Subject: [PATCH 20/26] refactor: replace `Vec` with `IndexMap` for expression mappings in `ProjectionMapping` and `EquivalenceGroup` (#13675) * refactor: replace Vec with IndexMap for expression mappings in ProjectionMapping and EquivalenceGroup * chore * chore: Fix CI * chore: comment * chore: simplify --- .../physical-expr/src/equivalence/class.rs | 34 +++++++------------ 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index d06a495d970a..cc26d12fb029 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -17,8 +17,8 @@ use super::{add_offset_to_expr, collapse_lex_req, ProjectionMapping}; use crate::{ - expressions::Column, physical_exprs_contains, LexOrdering, LexRequirement, - PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement, + expressions::Column, LexOrdering, LexRequirement, PhysicalExpr, PhysicalExprRef, + PhysicalSortExpr, PhysicalSortRequirement, }; use std::fmt::Display; use std::sync::Arc; @@ -27,7 +27,7 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::JoinType; use datafusion_physical_expr_common::physical_expr::format_physical_expr_list; -use indexmap::IndexSet; +use indexmap::{IndexMap, IndexSet}; /// A structure representing a expression known to be constant in a physical execution plan. /// @@ -546,28 +546,20 @@ impl EquivalenceGroup { .collect::>(); (new_class.len() > 1).then_some(EquivalenceClass::new(new_class)) }); - // TODO: Convert the algorithm below to a version that uses `HashMap`. - // once `Arc` can be stored in `HashMap`. - // See issue: https://github.com/apache/datafusion/issues/8027 - let mut new_classes = vec![]; - for (source, target) in mapping.iter() { - if new_classes.is_empty() { - new_classes.push((source, vec![Arc::clone(target)])); - } - if let Some((_, values)) = - new_classes.iter_mut().find(|(key, _)| *key == source) - { - if !physical_exprs_contains(values, target) { - values.push(Arc::clone(target)); - } - } - } + // the key is the source expression and the value is the EquivalenceClass that contains the target expression of the source expression. + let mut new_classes: IndexMap, EquivalenceClass> = + IndexMap::new(); + mapping.iter().for_each(|(source, target)| { + new_classes + .entry(Arc::clone(source)) + .or_insert_with(EquivalenceClass::new_empty) + .push(Arc::clone(target)); + }); // Only add equivalence classes with at least two members as singleton // equivalence classes are meaningless. let new_classes = new_classes .into_iter() - .filter_map(|(_, values)| (values.len() > 1).then_some(values)) - .map(EquivalenceClass::new); + .filter_map(|(_, cls)| (cls.len() > 1).then_some(cls)); let classes = projected_classes.chain(new_classes).collect(); Self::new(classes) From d3c459e797d066e1e5daa95d72679fe2186404cb Mon Sep 17 00:00:00 2001 From: Eason <30045503+Eason0729@users.noreply.github.com> Date: Wed, 11 Dec 2024 20:21:17 +0800 Subject: [PATCH 21/26] Handle alias when parsing sql(parse_sql_expr) (#12939) * fix: Fix parse_sql_expr not handling alias * cargo fmt * fix parse_sql_expr example(remove alias) * add testing * add SUM udaf to TestContextProvider and modify test_sql_to_expr_with_alias for function * revert change on example `parse_sql_expr` --- .../examples/parse_sql_expr.rs | 10 ++-- .../core/src/execution/session_state.rs | 21 +++++-- datafusion/sql/src/expr/mod.rs | 60 +++++++++++++++++-- datafusion/sql/src/parser.rs | 9 +-- 4 files changed, 82 insertions(+), 18 deletions(-) diff --git a/datafusion-examples/examples/parse_sql_expr.rs b/datafusion-examples/examples/parse_sql_expr.rs index e23e5accae39..d8f0778e19e3 100644 --- a/datafusion-examples/examples/parse_sql_expr.rs +++ b/datafusion-examples/examples/parse_sql_expr.rs @@ -121,11 +121,11 @@ async fn query_parquet_demo() -> Result<()> { assert_batches_eq!( &[ - "+------------+----------------------+", - "| double_col | sum(?table?.int_col) |", - "+------------+----------------------+", - "| 10.1 | 4 |", - "+------------+----------------------+", + "+------------+-------------+", + "| double_col | sum_int_col |", + "+------------+-------------+", + "| 10.1 | 4 |", + "+------------+-------------+", ], &result ); diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 4ccad5ffd323..cef5d4c1ee2a 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -68,7 +68,7 @@ use datafusion_sql::planner::{ContextProvider, ParserOptions, PlannerContext, Sq use itertools::Itertools; use log::{debug, info}; use object_store::ObjectStore; -use sqlparser::ast::Expr as SQLExpr; +use sqlparser::ast::{Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias}; use sqlparser::dialect::dialect_from_str; use std::any::Any; use std::collections::hash_map::Entry; @@ -500,11 +500,22 @@ impl SessionState { sql: &str, dialect: &str, ) -> datafusion_common::Result { + self.sql_to_expr_with_alias(sql, dialect).map(|x| x.expr) + } + + /// parse a sql string into a sqlparser-rs AST [`SQLExprWithAlias`]. + /// + /// See [`Self::create_logical_expr`] for parsing sql to [`Expr`]. + pub fn sql_to_expr_with_alias( + &self, + sql: &str, + dialect: &str, + ) -> datafusion_common::Result { let dialect = dialect_from_str(dialect).ok_or_else(|| { plan_datafusion_err!( "Unsupported SQL dialect: {dialect}. Available dialects: \ - Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \ - MsSQL, ClickHouse, BigQuery, Ansi." + Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \ + MsSQL, ClickHouse, BigQuery, Ansi." ) })?; @@ -603,7 +614,7 @@ impl SessionState { ) -> datafusion_common::Result { let dialect = self.config.options().sql_parser.dialect.as_str(); - let sql_expr = self.sql_to_expr(sql, dialect)?; + let sql_expr = self.sql_to_expr_with_alias(sql, dialect)?; let provider = SessionContextProvider { state: self, @@ -611,7 +622,7 @@ impl SessionState { }; let query = SqlToRel::new_with_options(&provider, self.get_parser_options()); - query.sql_to_expr(sql_expr, df_schema, &mut PlannerContext::new()) + query.sql_to_expr_with_alias(sql_expr, df_schema, &mut PlannerContext::new()) } /// Returns the [`Analyzer`] for this session diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 57ac96951f1f..e8ec8d7b7d1c 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -23,7 +23,8 @@ use datafusion_expr::planner::{ use recursive::recursive; use sqlparser::ast::{ BinaryOperator, CastFormat, CastKind, DataType as SQLDataType, DictionaryField, - Expr as SQLExpr, MapEntry, StructField, Subscript, TrimWhereField, Value, + Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias, MapEntry, StructField, Subscript, + TrimWhereField, Value, }; use datafusion_common::{ @@ -50,6 +51,19 @@ mod unary_op; mod value; impl SqlToRel<'_, S> { + pub(crate) fn sql_expr_to_logical_expr_with_alias( + &self, + sql: SQLExprWithAlias, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let mut expr = + self.sql_expr_to_logical_expr(sql.expr, schema, planner_context)?; + if let Some(alias) = sql.alias { + expr = expr.alias(alias.value); + } + Ok(expr) + } pub(crate) fn sql_expr_to_logical_expr( &self, sql: SQLExpr, @@ -131,6 +145,20 @@ impl SqlToRel<'_, S> { ))) } + pub fn sql_to_expr_with_alias( + &self, + sql: SQLExprWithAlias, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let mut expr = + self.sql_expr_to_logical_expr_with_alias(sql, schema, planner_context)?; + expr = self.rewrite_partial_qualifier(expr, schema); + self.validate_schema_satisfies_exprs(schema, &[expr.clone()])?; + let (expr, _) = expr.infer_placeholder_types(schema)?; + Ok(expr) + } + /// Generate a relational expression from a SQL expression pub fn sql_to_expr( &self, @@ -1091,8 +1119,11 @@ mod tests { None } - fn get_aggregate_meta(&self, _name: &str) -> Option> { - None + fn get_aggregate_meta(&self, name: &str) -> Option> { + match name { + "sum" => Some(datafusion_functions_aggregate::sum::sum_udaf()), + _ => None, + } } fn get_variable_type(&self, _variable_names: &[String]) -> Option { @@ -1112,7 +1143,7 @@ mod tests { } fn udaf_names(&self) -> Vec { - Vec::new() + vec!["sum".to_string()] } fn udwf_names(&self) -> Vec { @@ -1167,4 +1198,25 @@ mod tests { test_stack_overflow!(2048); test_stack_overflow!(4096); test_stack_overflow!(8192); + #[test] + fn test_sql_to_expr_with_alias() { + let schema = DFSchema::empty(); + let mut planner_context = PlannerContext::default(); + + let expr_str = "SUM(int_col) as sum_int_col"; + + let dialect = GenericDialect {}; + let mut parser = Parser::new(&dialect).try_with_sql(expr_str).unwrap(); + // from sqlparser + let sql_expr = parser.parse_expr_with_alias().unwrap(); + + let context_provider = TestContextProvider::new(); + let sql_to_rel = SqlToRel::new(&context_provider); + + let expr = sql_to_rel + .sql_expr_to_logical_expr_with_alias(sql_expr, &schema, &mut planner_context) + .unwrap(); + + assert!(matches!(expr, Expr::Alias(_))); + } } diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index bd1ed3145ef5..efec6020641c 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -20,9 +20,10 @@ use std::collections::VecDeque; use std::fmt; +use sqlparser::ast::ExprWithAlias; use sqlparser::{ ast::{ - ColumnDef, ColumnOptionDef, Expr, ObjectName, OrderByExpr, Query, + ColumnDef, ColumnOptionDef, ObjectName, OrderByExpr, Query, Statement as SQLStatement, TableConstraint, Value, }, dialect::{keywords::Keyword, Dialect, GenericDialect}, @@ -328,7 +329,7 @@ impl<'a> DFParser<'a> { pub fn parse_sql_into_expr_with_dialect( sql: &str, dialect: &dyn Dialect, - ) -> Result { + ) -> Result { let mut parser = DFParser::new_with_dialect(sql, dialect)?; parser.parse_expr() } @@ -377,7 +378,7 @@ impl<'a> DFParser<'a> { } } - pub fn parse_expr(&mut self) -> Result { + pub fn parse_expr(&mut self) -> Result { if let Token::Word(w) = self.parser.peek_token().token { match w.keyword { Keyword::CREATE | Keyword::COPY | Keyword::EXPLAIN => { @@ -387,7 +388,7 @@ impl<'a> DFParser<'a> { } } - self.parser.parse_expr() + self.parser.parse_expr_with_alias() } /// Parse a SQL `COPY TO` statement From ddfc9e5f95c468d845b65bf4f05a8fcb08366914 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 11 Dec 2024 07:27:46 -0500 Subject: [PATCH 22/26] Improve documentation for TableProvider (#13724) --- datafusion/catalog/src/table.rs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/datafusion/catalog/src/table.rs b/datafusion/catalog/src/table.rs index b6752191d9a7..3c8960495588 100644 --- a/datafusion/catalog/src/table.rs +++ b/datafusion/catalog/src/table.rs @@ -33,7 +33,19 @@ use datafusion_expr::{ }; use datafusion_physical_plan::ExecutionPlan; -/// Source table +/// A named table which can be queried. +/// +/// Please see [`CatalogProvider`] for details of implementing a custom catalog. +/// +/// [`TableProvider`] represents a source of data which can provide data as +/// Apache Arrow `RecordBatch`es. Implementations of this trait provide +/// important information for planning such as: +/// +/// 1. [`Self::schema`]: The schema (columns and their types) of the table +/// 2. [`Self::supports_filters_pushdown`]: Should filters be pushed into this scan +/// 2. [`Self::scan`]: An [`ExecutionPlan`] that can read data +/// +/// [`CatalogProvider`]: super::CatalogProvider #[async_trait] pub trait TableProvider: Debug + Sync + Send { /// Returns the table provider as [`Any`](std::any::Any) so that it can be From b494157b0d6750b51c8f5811629dab090d1a43c0 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Wed, 11 Dec 2024 14:29:23 +0100 Subject: [PATCH 23/26] Reveal implementing type and return type in simple UDF implementations (#13730) Debug trait is useful for understanding what something is and how it's configured, especially if the implementation is behind dyn trait. --- datafusion/expr/src/expr_fn.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 681eb3c0afd5..a44dd24039dc 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -416,9 +416,10 @@ pub struct SimpleScalarUDF { impl Debug for SimpleScalarUDF { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("ScalarUDF") + f.debug_struct("SimpleScalarUDF") .field("name", &self.name) .field("signature", &self.signature) + .field("return_type", &self.return_type) .field("fun", &"") .finish() } @@ -524,9 +525,10 @@ pub struct SimpleAggregateUDF { impl Debug for SimpleAggregateUDF { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("AggregateUDF") + f.debug_struct("SimpleAggregateUDF") .field("name", &self.name) .field("signature", &self.signature) + .field("return_type", &self.return_type) .field("fun", &"") .finish() } From 3b5daa2ac4659a3b590285752fe734d1c07315e8 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 11 Dec 2024 09:49:23 -0500 Subject: [PATCH 24/26] minor: Extract tests for `EXTRACT` AND `date_part` to their own file (#13731) --- datafusion/sqllogictest/test_files/expr.slt | 861 ----------------- .../test_files/expr/date_part.slt | 878 ++++++++++++++++++ 2 files changed, 878 insertions(+), 861 deletions(-) create mode 100644 datafusion/sqllogictest/test_files/expr/date_part.slt diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 9b8dfc2186be..2306eda77d35 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -832,867 +832,6 @@ SELECT ---- 0 NULL 0 NULL -# test_extract_date_part - -query error -SELECT EXTRACT("'''year'''" FROM timestamp '2020-09-08T12:00:00+00:00') - -query error -SELECT EXTRACT("'year'" FROM timestamp '2020-09-08T12:00:00+00:00') - -query I -SELECT date_part('YEAR', CAST('2000-01-01' AS DATE)) ----- -2000 - -query I -SELECT EXTRACT(year FROM timestamp '2020-09-08T12:00:00+00:00') ----- -2020 - -query I -SELECT EXTRACT("year" FROM timestamp '2020-09-08T12:00:00+00:00') ----- -2020 - -query I -SELECT EXTRACT('year' FROM timestamp '2020-09-08T12:00:00+00:00') ----- -2020 - -query I -SELECT date_part('QUARTER', CAST('2000-01-01' AS DATE)) ----- -1 - -query I -SELECT EXTRACT(quarter FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -3 - -query I -SELECT EXTRACT("quarter" FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -3 - -query I -SELECT EXTRACT('quarter' FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -3 - -query I -SELECT date_part('MONTH', CAST('2000-01-01' AS DATE)) ----- -1 - -query I -SELECT EXTRACT(month FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -9 - -query I -SELECT EXTRACT("month" FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -9 - -query I -SELECT EXTRACT('month' FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -9 - -query I -SELECT date_part('WEEK', CAST('2003-01-01' AS DATE)) ----- -1 - -query I -SELECT EXTRACT(WEEK FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -37 - -query I -SELECT EXTRACT("WEEK" FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -37 - -query I -SELECT EXTRACT('WEEK' FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -37 - -query I -SELECT date_part('DAY', CAST('2000-01-01' AS DATE)) ----- -1 - -query I -SELECT EXTRACT(day FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -8 - -query I -SELECT EXTRACT("day" FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -8 - -query I -SELECT EXTRACT('day' FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -8 - -query I -SELECT date_part('DOY', CAST('2000-01-01' AS DATE)) ----- -1 - -query I -SELECT EXTRACT(doy FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -252 - -query I -SELECT EXTRACT("doy" FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -252 - -query I -SELECT EXTRACT('doy' FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -252 - -query I -SELECT date_part('DOW', CAST('2000-01-01' AS DATE)) ----- -6 - -query I -SELECT EXTRACT(dow FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -2 - -query I -SELECT EXTRACT("dow" FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -2 - -query I -SELECT EXTRACT('dow' FROM to_timestamp('2020-09-08T12:00:00+00:00')) ----- -2 - -query I -SELECT date_part('HOUR', CAST('2000-01-01' AS DATE)) ----- -0 - -query I -SELECT EXTRACT(hour FROM to_timestamp('2020-09-08T12:03:03+00:00')) ----- -12 - -query I -SELECT EXTRACT("hour" FROM to_timestamp('2020-09-08T12:03:03+00:00')) ----- -12 - -query I -SELECT EXTRACT('hour' FROM to_timestamp('2020-09-08T12:03:03+00:00')) ----- -12 - -query I -SELECT EXTRACT(minute FROM to_timestamp('2020-09-08T12:12:00+00:00')) ----- -12 - -query I -SELECT EXTRACT("minute" FROM to_timestamp('2020-09-08T12:12:00+00:00')) ----- -12 - -query I -SELECT EXTRACT('minute' FROM to_timestamp('2020-09-08T12:12:00+00:00')) ----- -12 - -query I -SELECT date_part('minute', to_timestamp('2020-09-08T12:12:00+00:00')) ----- -12 - -# make sure the return type is integer -query T -SELECT arrow_typeof(date_part('minute', to_timestamp('2020-09-08T12:12:00+00:00'))) ----- -Int32 - -query I -SELECT EXTRACT(second FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12 - -query I -SELECT EXTRACT(millisecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12123 - -query I -SELECT EXTRACT(microsecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12123456 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT EXTRACT(nanosecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') - -query I -SELECT EXTRACT("second" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12 - -query I -SELECT EXTRACT("millisecond" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12123 - -query I -SELECT EXTRACT("microsecond" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12123456 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT EXTRACT("nanosecond" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') - -query I -SELECT EXTRACT('second' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12 - -query I -SELECT EXTRACT('millisecond' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12123 - -query I -SELECT EXTRACT('microsecond' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12123456 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT EXTRACT('nanosecond' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') - - -# Keep precision when coercing Utf8 to Timestamp -query I -SELECT date_part('second', timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12 - -query I -SELECT date_part('millisecond', timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12123 - -query I -SELECT date_part('microsecond', timestamp '2020-09-08T12:00:12.12345678+00:00') ----- -12123456 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT date_part('nanosecond', timestamp '2020-09-08T12:00:12.12345678+00:00') - - -query I -SELECT date_part('second', '2020-09-08T12:00:12.12345678+00:00') ----- -12 - -query I -SELECT date_part('millisecond', '2020-09-08T12:00:12.12345678+00:00') ----- -12123 - -query I -SELECT date_part('microsecond', '2020-09-08T12:00:12.12345678+00:00') ----- -12123456 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT date_part('nanosecond', '2020-09-08T12:00:12.12345678+00:00') - -# test_date_part_time - -## time32 seconds -query I -SELECT date_part('hour', arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -23 - -query I -SELECT extract(hour from arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -23 - -query I -SELECT date_part('minute', arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -32 - -query I -SELECT extract(minute from arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -32 - -query I -SELECT date_part('second', arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -50 - -query I -SELECT extract(second from arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -50 - -query I -SELECT date_part('millisecond', arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -50000 - -query I -SELECT extract(millisecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -50000 - -query I -SELECT date_part('microsecond', arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -50000000 - -query I -SELECT extract(microsecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -50000000 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT extract(nanosecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) - -query R -SELECT date_part('epoch', arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -84770 - -query R -SELECT extract(epoch from arrow_cast('23:32:50'::time, 'Time32(Second)')) ----- -84770 - -## time32 milliseconds -query I -SELECT date_part('hour', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -23 - -query I -SELECT extract(hour from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -23 - -query I -SELECT date_part('minute', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -32 - -query I -SELECT extract(minute from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -32 - -query I -SELECT date_part('second', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -50 - -query I -SELECT extract(second from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -50 - -query I -SELECT date_part('millisecond', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -50123 - -query I -SELECT extract(millisecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -50123 - -query I -SELECT date_part('microsecond', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -50123000 - -query I -SELECT extract(microsecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -50123000 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT extract(nanosecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) - -query R -SELECT date_part('epoch', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -84770.123 - -query R -SELECT extract(epoch from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) ----- -84770.123 - -## time64 microseconds -query I -SELECT date_part('hour', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -23 - -query I -SELECT extract(hour from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -23 - -query I -SELECT date_part('minute', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -32 - -query I -SELECT extract(minute from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -32 - -query I -SELECT date_part('second', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -50 - -query I -SELECT extract(second from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -50 - -query I -SELECT date_part('millisecond', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -50123 - -query I -SELECT extract(millisecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -50123 - -query I -SELECT date_part('microsecond', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -50123456 - -query I -SELECT extract(microsecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -50123456 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT extract(nanosecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) - -query R -SELECT date_part('epoch', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -84770.123456 - -query R -SELECT extract(epoch from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) ----- -84770.123456 - -## time64 nanoseconds -query I -SELECT date_part('hour', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -23 - -query I -SELECT extract(hour from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -23 - -query I -SELECT date_part('minute', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -32 - -query I -SELECT extract(minute from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -32 - -query I -SELECT date_part('second', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -50 - -query I -select extract(second from '2024-08-09T12:13:14') ----- -14 - -query I -select extract(seconds from '2024-08-09T12:13:14') ----- -14 - -query I -SELECT extract(second from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -50 - -query I -SELECT date_part('millisecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -50123 - -query I -SELECT extract(millisecond from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -50123 - -# just some floating point stuff happening in the result here -query I -SELECT date_part('microsecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -50123456 - -query I -SELECT extract(microsecond from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -50123456 - -query I -SELECT extract(us from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -50123456 - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT date_part('nanosecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) - -query R -SELECT date_part('epoch', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -84770.123456789 - -query R -SELECT extract(epoch from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ----- -84770.123456789 - -# test_extract_epoch - -query R -SELECT extract(epoch from '1870-01-01T07:29:10.256'::timestamp) ----- --3155646649.744 - -query R -SELECT extract(epoch from '2000-01-01T00:00:00.000'::timestamp) ----- -946684800 - -query R -SELECT extract(epoch from to_timestamp('2000-01-01T00:00:00+00:00')) ----- -946684800 - -query R -SELECT extract(epoch from NULL::timestamp) ----- -NULL - -query R -SELECT extract(epoch from arrow_cast('1970-01-01', 'Date32')) ----- -0 - -query R -SELECT extract(epoch from arrow_cast('1970-01-02', 'Date32')) ----- -86400 - -query R -SELECT extract(epoch from arrow_cast('1970-01-11', 'Date32')) ----- -864000 - -query R -SELECT extract(epoch from arrow_cast('1969-12-31', 'Date32')) ----- --86400 - -query R -SELECT extract(epoch from arrow_cast('1970-01-01', 'Date64')) ----- -0 - -query R -SELECT extract(epoch from arrow_cast('1970-01-02', 'Date64')) ----- -86400 - -query R -SELECT extract(epoch from arrow_cast('1970-01-11', 'Date64')) ----- -864000 - -query R -SELECT extract(epoch from arrow_cast('1969-12-31', 'Date64')) ----- --86400 - -# test_extract_interval - -query I -SELECT extract(year from arrow_cast('10 years', 'Interval(YearMonth)')) ----- -10 - -query I -SELECT extract(month from arrow_cast('10 years', 'Interval(YearMonth)')) ----- -0 - -query I -SELECT extract(year from arrow_cast('10 months', 'Interval(YearMonth)')) ----- -0 - -query I -SELECT extract(month from arrow_cast('10 months', 'Interval(YearMonth)')) ----- -10 - -query I -SELECT extract(year from arrow_cast('20 months', 'Interval(YearMonth)')) ----- -1 - -query I -SELECT extract(month from arrow_cast('20 months', 'Interval(YearMonth)')) ----- -8 - -query error DataFusion error: Arrow error: Compute error: Year does not support: Interval\(DayTime\) -SELECT extract(year from arrow_cast('10 days', 'Interval(DayTime)')) - -query error DataFusion error: Arrow error: Compute error: Month does not support: Interval\(DayTime\) -SELECT extract(month from arrow_cast('10 days', 'Interval(DayTime)')) - -query I -SELECT extract(day from arrow_cast('10 days', 'Interval(DayTime)')) ----- -10 - -query I -SELECT extract(day from arrow_cast('14400 minutes', 'Interval(DayTime)')) ----- -0 - -query I -SELECT extract(minute from arrow_cast('14400 minutes', 'Interval(DayTime)')) ----- -14400 - -query I -SELECT extract(second from arrow_cast('5.1 seconds', 'Interval(DayTime)')) ----- -5 - -query I -SELECT extract(second from arrow_cast('14400 minutes', 'Interval(DayTime)')) ----- -864000 - -query I -SELECT extract(second from arrow_cast('2 months', 'Interval(MonthDayNano)')) ----- -0 - -query I -SELECT extract(second from arrow_cast('2 days', 'Interval(MonthDayNano)')) ----- -0 - -query I -SELECT extract(second from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) ----- -2 - -query I -SELECT extract(seconds from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) ----- -2 - -query R -SELECT extract(epoch from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) ----- -2 - -query I -SELECT extract(milliseconds from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) ----- -2000 - -query I -SELECT extract(second from arrow_cast('2030 milliseconds', 'Interval(MonthDayNano)')) ----- -2 - -query I -SELECT extract(second from arrow_cast(NULL, 'Interval(MonthDayNano)')) ----- -NULL - -statement ok -create table t (id int, i interval) as values - (0, interval '5 months 1 day 10 nanoseconds'), - (1, interval '1 year 3 months'), - (2, interval '3 days 2 milliseconds'), - (3, interval '2 seconds'), - (4, interval '8 months'), - (5, NULL); - -query III -select - id, - extract(second from i), - extract(month from i) -from t -order by id; ----- -0 0 5 -1 0 15 -2 0 0 -3 2 0 -4 0 8 -5 NULL NULL - -statement ok -drop table t; - -# test_extract_duration - -query I -SELECT extract(second from arrow_cast(2, 'Duration(Second)')) ----- -2 - -query I -SELECT extract(seconds from arrow_cast(2, 'Duration(Second)')) ----- -2 - -query R -SELECT extract(epoch from arrow_cast(2, 'Duration(Second)')) ----- -2 - -query I -SELECT extract(millisecond from arrow_cast(2, 'Duration(Second)')) ----- -2000 - -query I -SELECT extract(second from arrow_cast(2, 'Duration(Millisecond)')) ----- -0 - -query I -SELECT extract(second from arrow_cast(2002, 'Duration(Millisecond)')) ----- -2 - -query I -SELECT extract(millisecond from arrow_cast(2002, 'Duration(Millisecond)')) ----- -2002 - -query I -SELECT extract(day from arrow_cast(864000, 'Duration(Second)')) ----- -10 - -query error DataFusion error: Arrow error: Compute error: Month does not support: Duration\(Second\) -SELECT extract(month from arrow_cast(864000, 'Duration(Second)')) - -query error DataFusion error: Arrow error: Compute error: Year does not support: Duration\(Second\) -SELECT extract(year from arrow_cast(864000, 'Duration(Second)')) - -query I -SELECT extract(day from arrow_cast(NULL, 'Duration(Second)')) ----- -NULL - -# test_extract_date_part_func - -query B -SELECT (date_part('year', now()) = EXTRACT(year FROM now())) ----- -true - -query B -SELECT (date_part('quarter', now()) = EXTRACT(quarter FROM now())) ----- -true - -query B -SELECT (date_part('month', now()) = EXTRACT(month FROM now())) ----- -true - -query B -SELECT (date_part('week', now()) = EXTRACT(week FROM now())) ----- -true - -query B -SELECT (date_part('day', now()) = EXTRACT(day FROM now())) ----- -true - -query B -SELECT (date_part('hour', now()) = EXTRACT(hour FROM now())) ----- -true - -query B -SELECT (date_part('minute', now()) = EXTRACT(minute FROM now())) ----- -true - -query B -SELECT (date_part('second', now()) = EXTRACT(second FROM now())) ----- -true - -query B -SELECT (date_part('millisecond', now()) = EXTRACT(millisecond FROM now())) ----- -true - -query B -SELECT (date_part('microsecond', now()) = EXTRACT(microsecond FROM now())) ----- -true - -query error DataFusion error: Internal error: unit Nanosecond not supported -SELECT (date_part('nanosecond', now()) = EXTRACT(nanosecond FROM now())) - query B SELECT 'a' IN ('a','b') ---- diff --git a/datafusion/sqllogictest/test_files/expr/date_part.slt b/datafusion/sqllogictest/test_files/expr/date_part.slt new file mode 100644 index 000000000000..cec80a165f30 --- /dev/null +++ b/datafusion/sqllogictest/test_files/expr/date_part.slt @@ -0,0 +1,878 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Tests for `date_part` and `EXTRACT` (which is a different syntax +# for the same function). + +query error +SELECT EXTRACT("'''year'''" FROM timestamp '2020-09-08T12:00:00+00:00') + +query error +SELECT EXTRACT("'year'" FROM timestamp '2020-09-08T12:00:00+00:00') + +query I +SELECT date_part('YEAR', CAST('2000-01-01' AS DATE)) +---- +2000 + +query I +SELECT EXTRACT(year FROM timestamp '2020-09-08T12:00:00+00:00') +---- +2020 + +query I +SELECT EXTRACT("year" FROM timestamp '2020-09-08T12:00:00+00:00') +---- +2020 + +query I +SELECT EXTRACT('year' FROM timestamp '2020-09-08T12:00:00+00:00') +---- +2020 + +query I +SELECT date_part('QUARTER', CAST('2000-01-01' AS DATE)) +---- +1 + +query I +SELECT EXTRACT(quarter FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +3 + +query I +SELECT EXTRACT("quarter" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +3 + +query I +SELECT EXTRACT('quarter' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +3 + +query I +SELECT date_part('MONTH', CAST('2000-01-01' AS DATE)) +---- +1 + +query I +SELECT EXTRACT(month FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +9 + +query I +SELECT EXTRACT("month" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +9 + +query I +SELECT EXTRACT('month' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +9 + +query I +SELECT date_part('WEEK', CAST('2003-01-01' AS DATE)) +---- +1 + +query I +SELECT EXTRACT(WEEK FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +37 + +query I +SELECT EXTRACT("WEEK" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +37 + +query I +SELECT EXTRACT('WEEK' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +37 + +query I +SELECT date_part('DAY', CAST('2000-01-01' AS DATE)) +---- +1 + +query I +SELECT EXTRACT(day FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +8 + +query I +SELECT EXTRACT("day" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +8 + +query I +SELECT EXTRACT('day' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +8 + +query I +SELECT date_part('DOY', CAST('2000-01-01' AS DATE)) +---- +1 + +query I +SELECT EXTRACT(doy FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +252 + +query I +SELECT EXTRACT("doy" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +252 + +query I +SELECT EXTRACT('doy' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +252 + +query I +SELECT date_part('DOW', CAST('2000-01-01' AS DATE)) +---- +6 + +query I +SELECT EXTRACT(dow FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +2 + +query I +SELECT EXTRACT("dow" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +2 + +query I +SELECT EXTRACT('dow' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +2 + +query I +SELECT date_part('HOUR', CAST('2000-01-01' AS DATE)) +---- +0 + +query I +SELECT EXTRACT(hour FROM to_timestamp('2020-09-08T12:03:03+00:00')) +---- +12 + +query I +SELECT EXTRACT("hour" FROM to_timestamp('2020-09-08T12:03:03+00:00')) +---- +12 + +query I +SELECT EXTRACT('hour' FROM to_timestamp('2020-09-08T12:03:03+00:00')) +---- +12 + +query I +SELECT EXTRACT(minute FROM to_timestamp('2020-09-08T12:12:00+00:00')) +---- +12 + +query I +SELECT EXTRACT("minute" FROM to_timestamp('2020-09-08T12:12:00+00:00')) +---- +12 + +query I +SELECT EXTRACT('minute' FROM to_timestamp('2020-09-08T12:12:00+00:00')) +---- +12 + +query I +SELECT date_part('minute', to_timestamp('2020-09-08T12:12:00+00:00')) +---- +12 + +# make sure the return type is integer +query T +SELECT arrow_typeof(date_part('minute', to_timestamp('2020-09-08T12:12:00+00:00'))) +---- +Int32 + +query I +SELECT EXTRACT(second FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12 + +query I +SELECT EXTRACT(millisecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123 + +query I +SELECT EXTRACT(microsecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456 + +query error DataFusion error: Internal error: unit Nanosecond not supported +SELECT EXTRACT(nanosecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') + +query I +SELECT EXTRACT("second" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12 + +query I +SELECT EXTRACT("millisecond" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123 + +query I +SELECT EXTRACT("microsecond" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456 + +query error DataFusion error: Internal error: unit Nanosecond not supported +SELECT EXTRACT("nanosecond" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') + +query I +SELECT EXTRACT('second' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12 + +query I +SELECT EXTRACT('millisecond' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123 + +query I +SELECT EXTRACT('microsecond' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456 + +query error DataFusion error: Internal error: unit Nanosecond not supported +SELECT EXTRACT('nanosecond' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') + + +# Keep precision when coercing Utf8 to Timestamp +query I +SELECT date_part('second', timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12 + +query I +SELECT date_part('millisecond', timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123 + +query I +SELECT date_part('microsecond', timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456 + +query error DataFusion error: Internal error: unit Nanosecond not supported +SELECT date_part('nanosecond', timestamp '2020-09-08T12:00:12.12345678+00:00') + + +query I +SELECT date_part('second', '2020-09-08T12:00:12.12345678+00:00') +---- +12 + +query I +SELECT date_part('millisecond', '2020-09-08T12:00:12.12345678+00:00') +---- +12123 + +query I +SELECT date_part('microsecond', '2020-09-08T12:00:12.12345678+00:00') +---- +12123456 + +query error DataFusion error: Internal error: unit Nanosecond not supported +SELECT date_part('nanosecond', '2020-09-08T12:00:12.12345678+00:00') + +# test_date_part_time + +## time32 seconds +query I +SELECT date_part('hour', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +23 + +query I +SELECT extract(hour from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +23 + +query I +SELECT date_part('minute', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +32 + +query I +SELECT extract(minute from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +32 + +query I +SELECT date_part('second', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50 + +query I +SELECT extract(second from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50 + +query I +SELECT date_part('millisecond', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000 + +query I +SELECT extract(millisecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000 + +query I +SELECT date_part('microsecond', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000000 + +query I +SELECT extract(microsecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000000 + +query error DataFusion error: Internal error: unit Nanosecond not supported +SELECT extract(nanosecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) + +query R +SELECT date_part('epoch', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +84770 + +query R +SELECT extract(epoch from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +84770 + +## time32 milliseconds +query I +SELECT date_part('hour', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +23 + +query I +SELECT extract(hour from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +23 + +query I +SELECT date_part('minute', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +32 + +query I +SELECT extract(minute from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +32 + +query I +SELECT date_part('second', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50 + +query I +SELECT extract(second from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50 + +query I +SELECT date_part('millisecond', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123 + +query I +SELECT extract(millisecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123 + +query I +SELECT date_part('microsecond', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123000 + +query I +SELECT extract(microsecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123000 + +query error DataFusion error: Internal error: unit Nanosecond not supported +SELECT extract(nanosecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) + +query R +SELECT date_part('epoch', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +84770.123 + +query R +SELECT extract(epoch from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +84770.123 + +## time64 microseconds +query I +SELECT date_part('hour', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +23 + +query I +SELECT extract(hour from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +23 + +query I +SELECT date_part('minute', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +32 + +query I +SELECT extract(minute from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +32 + +query I +SELECT date_part('second', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50 + +query I +SELECT extract(second from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50 + +query I +SELECT date_part('millisecond', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123 + +query I +SELECT extract(millisecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123 + +query I +SELECT date_part('microsecond', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123456 + +query I +SELECT extract(microsecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123456 + +query error DataFusion error: Internal error: unit Nanosecond not supported +SELECT extract(nanosecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) + +query R +SELECT date_part('epoch', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +84770.123456 + +query R +SELECT extract(epoch from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +84770.123456 + +## time64 nanoseconds +query I +SELECT date_part('hour', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +23 + +query I +SELECT extract(hour from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +23 + +query I +SELECT date_part('minute', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +32 + +query I +SELECT extract(minute from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +32 + +query I +SELECT date_part('second', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50 + +query I +select extract(second from '2024-08-09T12:13:14') +---- +14 + +query I +select extract(seconds from '2024-08-09T12:13:14') +---- +14 + +query I +SELECT extract(second from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50 + +query I +SELECT date_part('millisecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123 + +query I +SELECT extract(millisecond from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123 + +# just some floating point stuff happening in the result here +query I +SELECT date_part('microsecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123456 + +query I +SELECT extract(microsecond from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123456 + +query I +SELECT extract(us from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123456 + +query error DataFusion error: Internal error: unit Nanosecond not supported +SELECT date_part('nanosecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) + +query R +SELECT date_part('epoch', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +84770.123456789 + +query R +SELECT extract(epoch from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +84770.123456789 + +# test_extract_epoch + +query R +SELECT extract(epoch from '1870-01-01T07:29:10.256'::timestamp) +---- +-3155646649.744 + +query R +SELECT extract(epoch from '2000-01-01T00:00:00.000'::timestamp) +---- +946684800 + +query R +SELECT extract(epoch from to_timestamp('2000-01-01T00:00:00+00:00')) +---- +946684800 + +query R +SELECT extract(epoch from NULL::timestamp) +---- +NULL + +query R +SELECT extract(epoch from arrow_cast('1970-01-01', 'Date32')) +---- +0 + +query R +SELECT extract(epoch from arrow_cast('1970-01-02', 'Date32')) +---- +86400 + +query R +SELECT extract(epoch from arrow_cast('1970-01-11', 'Date32')) +---- +864000 + +query R +SELECT extract(epoch from arrow_cast('1969-12-31', 'Date32')) +---- +-86400 + +query R +SELECT extract(epoch from arrow_cast('1970-01-01', 'Date64')) +---- +0 + +query R +SELECT extract(epoch from arrow_cast('1970-01-02', 'Date64')) +---- +86400 + +query R +SELECT extract(epoch from arrow_cast('1970-01-11', 'Date64')) +---- +864000 + +query R +SELECT extract(epoch from arrow_cast('1969-12-31', 'Date64')) +---- +-86400 + +# test_extract_interval + +query I +SELECT extract(year from arrow_cast('10 years', 'Interval(YearMonth)')) +---- +10 + +query I +SELECT extract(month from arrow_cast('10 years', 'Interval(YearMonth)')) +---- +0 + +query I +SELECT extract(year from arrow_cast('10 months', 'Interval(YearMonth)')) +---- +0 + +query I +SELECT extract(month from arrow_cast('10 months', 'Interval(YearMonth)')) +---- +10 + +query I +SELECT extract(year from arrow_cast('20 months', 'Interval(YearMonth)')) +---- +1 + +query I +SELECT extract(month from arrow_cast('20 months', 'Interval(YearMonth)')) +---- +8 + +query error DataFusion error: Arrow error: Compute error: Year does not support: Interval\(DayTime\) +SELECT extract(year from arrow_cast('10 days', 'Interval(DayTime)')) + +query error DataFusion error: Arrow error: Compute error: Month does not support: Interval\(DayTime\) +SELECT extract(month from arrow_cast('10 days', 'Interval(DayTime)')) + +query I +SELECT extract(day from arrow_cast('10 days', 'Interval(DayTime)')) +---- +10 + +query I +SELECT extract(day from arrow_cast('14400 minutes', 'Interval(DayTime)')) +---- +0 + +query I +SELECT extract(minute from arrow_cast('14400 minutes', 'Interval(DayTime)')) +---- +14400 + +query I +SELECT extract(second from arrow_cast('5.1 seconds', 'Interval(DayTime)')) +---- +5 + +query I +SELECT extract(second from arrow_cast('14400 minutes', 'Interval(DayTime)')) +---- +864000 + +query I +SELECT extract(second from arrow_cast('2 months', 'Interval(MonthDayNano)')) +---- +0 + +query I +SELECT extract(second from arrow_cast('2 days', 'Interval(MonthDayNano)')) +---- +0 + +query I +SELECT extract(second from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) +---- +2 + +query I +SELECT extract(seconds from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) +---- +2 + +query R +SELECT extract(epoch from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) +---- +2 + +query I +SELECT extract(milliseconds from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) +---- +2000 + +query I +SELECT extract(second from arrow_cast('2030 milliseconds', 'Interval(MonthDayNano)')) +---- +2 + +query I +SELECT extract(second from arrow_cast(NULL, 'Interval(MonthDayNano)')) +---- +NULL + +statement ok +create table t (id int, i interval) as values + (0, interval '5 months 1 day 10 nanoseconds'), + (1, interval '1 year 3 months'), + (2, interval '3 days 2 milliseconds'), + (3, interval '2 seconds'), + (4, interval '8 months'), + (5, NULL); + +query III +select + id, + extract(second from i), + extract(month from i) +from t +order by id; +---- +0 0 5 +1 0 15 +2 0 0 +3 2 0 +4 0 8 +5 NULL NULL + +statement ok +drop table t; + +# test_extract_duration + +query I +SELECT extract(second from arrow_cast(2, 'Duration(Second)')) +---- +2 + +query I +SELECT extract(seconds from arrow_cast(2, 'Duration(Second)')) +---- +2 + +query R +SELECT extract(epoch from arrow_cast(2, 'Duration(Second)')) +---- +2 + +query I +SELECT extract(millisecond from arrow_cast(2, 'Duration(Second)')) +---- +2000 + +query I +SELECT extract(second from arrow_cast(2, 'Duration(Millisecond)')) +---- +0 + +query I +SELECT extract(second from arrow_cast(2002, 'Duration(Millisecond)')) +---- +2 + +query I +SELECT extract(millisecond from arrow_cast(2002, 'Duration(Millisecond)')) +---- +2002 + +query I +SELECT extract(day from arrow_cast(864000, 'Duration(Second)')) +---- +10 + +query error DataFusion error: Arrow error: Compute error: Month does not support: Duration\(Second\) +SELECT extract(month from arrow_cast(864000, 'Duration(Second)')) + +query error DataFusion error: Arrow error: Compute error: Year does not support: Duration\(Second\) +SELECT extract(year from arrow_cast(864000, 'Duration(Second)')) + +query I +SELECT extract(day from arrow_cast(NULL, 'Duration(Second)')) +---- +NULL + +# test_extract_date_part_func + +query B +SELECT (date_part('year', now()) = EXTRACT(year FROM now())) +---- +true + +query B +SELECT (date_part('quarter', now()) = EXTRACT(quarter FROM now())) +---- +true + +query B +SELECT (date_part('month', now()) = EXTRACT(month FROM now())) +---- +true + +query B +SELECT (date_part('week', now()) = EXTRACT(week FROM now())) +---- +true + +query B +SELECT (date_part('day', now()) = EXTRACT(day FROM now())) +---- +true + +query B +SELECT (date_part('hour', now()) = EXTRACT(hour FROM now())) +---- +true + +query B +SELECT (date_part('minute', now()) = EXTRACT(minute FROM now())) +---- +true + +query B +SELECT (date_part('second', now()) = EXTRACT(second FROM now())) +---- +true + +query B +SELECT (date_part('millisecond', now()) = EXTRACT(millisecond FROM now())) +---- +true + +query B +SELECT (date_part('microsecond', now()) = EXTRACT(microsecond FROM now())) +---- +true + +query error DataFusion error: Internal error: unit Nanosecond not supported +SELECT (date_part('nanosecond', now()) = EXTRACT(nanosecond FROM now())) From 50ce88385c213be0ea90b5d205cc737db7a01ee3 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Wed, 11 Dec 2024 23:27:31 +0800 Subject: [PATCH 25/26] Support unparsing `UNNEST` plan to `UNNEST` table factor SQL (#13660) * add `unnest_as_table_factor` and `UnnestRelationBuilder` * unparse unnest as table factor * fix typo * add tests for the default configs * add a static const for unnest_placeholder * fix tests * fix tests --- datafusion/sql/src/unparser/ast.rs | 73 ++++++++++++++ datafusion/sql/src/unparser/dialect.rs | 23 +++++ datafusion/sql/src/unparser/plan.rs | 55 ++++++++++- datafusion/sql/src/unparser/utils.rs | 2 +- datafusion/sql/src/utils.rs | 50 +++++----- datafusion/sql/tests/cases/plan_to_sql.rs | 99 ++++++++++++++++++- .../sqllogictest/test_files/encoding.slt | 2 +- datafusion/sqllogictest/test_files/joins.slt | 12 +-- .../test_files/push_down_filter.slt | 40 ++++---- .../test_files/table_functions.slt | 2 +- datafusion/sqllogictest/test_files/unnest.slt | 28 +++--- datafusion/sqllogictest/test_files/window.slt | 1 - 12 files changed, 313 insertions(+), 74 deletions(-) diff --git a/datafusion/sql/src/unparser/ast.rs b/datafusion/sql/src/unparser/ast.rs index cc0812cd71e1..ad0b5f16b283 100644 --- a/datafusion/sql/src/unparser/ast.rs +++ b/datafusion/sql/src/unparser/ast.rs @@ -353,6 +353,7 @@ pub(super) struct RelationBuilder { enum TableFactorBuilder { Table(TableRelationBuilder), Derived(DerivedRelationBuilder), + Unnest(UnnestRelationBuilder), Empty, } @@ -369,6 +370,12 @@ impl RelationBuilder { self.relation = Some(TableFactorBuilder::Derived(value)); self } + + pub fn unnest(&mut self, value: UnnestRelationBuilder) -> &mut Self { + self.relation = Some(TableFactorBuilder::Unnest(value)); + self + } + pub fn empty(&mut self) -> &mut Self { self.relation = Some(TableFactorBuilder::Empty); self @@ -382,6 +389,9 @@ impl RelationBuilder { Some(TableFactorBuilder::Derived(ref mut rel_builder)) => { rel_builder.alias = value; } + Some(TableFactorBuilder::Unnest(ref mut rel_builder)) => { + rel_builder.alias = value; + } Some(TableFactorBuilder::Empty) => (), None => (), } @@ -391,6 +401,7 @@ impl RelationBuilder { Ok(match self.relation { Some(TableFactorBuilder::Table(ref value)) => Some(value.build()?), Some(TableFactorBuilder::Derived(ref value)) => Some(value.build()?), + Some(TableFactorBuilder::Unnest(ref value)) => Some(value.build()?), Some(TableFactorBuilder::Empty) => None, None => return Err(Into::into(UninitializedFieldError::from("relation"))), }) @@ -526,6 +537,68 @@ impl Default for DerivedRelationBuilder { } } +#[derive(Clone)] +pub(super) struct UnnestRelationBuilder { + pub alias: Option, + pub array_exprs: Vec, + with_offset: bool, + with_offset_alias: Option, + with_ordinality: bool, +} + +#[allow(dead_code)] +impl UnnestRelationBuilder { + pub fn alias(&mut self, value: Option) -> &mut Self { + self.alias = value; + self + } + pub fn array_exprs(&mut self, value: Vec) -> &mut Self { + self.array_exprs = value; + self + } + + pub fn with_offset(&mut self, value: bool) -> &mut Self { + self.with_offset = value; + self + } + + pub fn with_offset_alias(&mut self, value: Option) -> &mut Self { + self.with_offset_alias = value; + self + } + + pub fn with_ordinality(&mut self, value: bool) -> &mut Self { + self.with_ordinality = value; + self + } + + pub fn build(&self) -> Result { + Ok(ast::TableFactor::UNNEST { + alias: self.alias.clone(), + array_exprs: self.array_exprs.clone(), + with_offset: self.with_offset, + with_offset_alias: self.with_offset_alias.clone(), + with_ordinality: self.with_ordinality, + }) + } + + fn create_empty() -> Self { + Self { + alias: Default::default(), + array_exprs: Default::default(), + with_offset: Default::default(), + with_offset_alias: Default::default(), + with_ordinality: Default::default(), + } + } +} + +impl Default for UnnestRelationBuilder { + fn default() -> Self { + Self::create_empty() + } +} + /// Runtime error when a `build()` method is called and one or more required fields /// do not have a value. #[derive(Debug, Clone)] diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index e979d8fd4ebd..ae387d441fa2 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -157,6 +157,15 @@ pub trait Dialect: Send + Sync { fn full_qualified_col(&self) -> bool { false } + + /// Allow to unparse the unnest plan as [ast::TableFactor::UNNEST]. + /// + /// Some dialects like BigQuery require UNNEST to be used in the FROM clause but + /// the LogicalPlan planner always puts UNNEST in the SELECT clause. This flag allows + /// to unparse the UNNEST plan as [ast::TableFactor::UNNEST] instead of a subquery. + fn unnest_as_table_factor(&self) -> bool { + false + } } /// `IntervalStyle` to use for unparsing @@ -448,6 +457,7 @@ pub struct CustomDialect { requires_derived_table_alias: bool, division_operator: BinaryOperator, full_qualified_col: bool, + unnest_as_table_factor: bool, } impl Default for CustomDialect { @@ -474,6 +484,7 @@ impl Default for CustomDialect { requires_derived_table_alias: false, division_operator: BinaryOperator::Divide, full_qualified_col: false, + unnest_as_table_factor: false, } } } @@ -582,6 +593,10 @@ impl Dialect for CustomDialect { fn full_qualified_col(&self) -> bool { self.full_qualified_col } + + fn unnest_as_table_factor(&self) -> bool { + self.unnest_as_table_factor + } } /// `CustomDialectBuilder` to build `CustomDialect` using builder pattern @@ -617,6 +632,7 @@ pub struct CustomDialectBuilder { requires_derived_table_alias: bool, division_operator: BinaryOperator, full_qualified_col: bool, + unnest_as_table_factor: bool, } impl Default for CustomDialectBuilder { @@ -649,6 +665,7 @@ impl CustomDialectBuilder { requires_derived_table_alias: false, division_operator: BinaryOperator::Divide, full_qualified_col: false, + unnest_as_table_factor: false, } } @@ -673,6 +690,7 @@ impl CustomDialectBuilder { requires_derived_table_alias: self.requires_derived_table_alias, division_operator: self.division_operator, full_qualified_col: self.full_qualified_col, + unnest_as_table_factor: self.unnest_as_table_factor, } } @@ -800,4 +818,9 @@ impl CustomDialectBuilder { self.full_qualified_col = full_qualified_col; self } + + pub fn with_unnest_as_table_factor(mut self, _unnest_as_table_factor: bool) -> Self { + self.unnest_as_table_factor = _unnest_as_table_factor; + self + } } diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index eaae4fe73d8c..e9f9f486ea9a 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -32,7 +32,9 @@ use super::{ }, Unparser, }; +use crate::unparser::ast::UnnestRelationBuilder; use crate::unparser::utils::unproject_agg_exprs; +use crate::utils::UNNEST_PLACEHOLDER; use datafusion_common::{ internal_err, not_impl_err, tree_node::{TransformedResult, TreeNode}, @@ -40,7 +42,7 @@ use datafusion_common::{ }; use datafusion_expr::{ expr::Alias, BinaryExpr, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, - LogicalPlanBuilder, Operator, Projection, SortExpr, TableScan, + LogicalPlanBuilder, Operator, Projection, SortExpr, TableScan, Unnest, }; use sqlparser::ast::{self, Ident, SetExpr}; use std::sync::Arc; @@ -312,6 +314,19 @@ impl Unparser<'_> { .select_to_sql_recursively(&new_plan, query, select, relation); } + // Projection can be top-level plan for unnest relation + // The projection generated by the `RecursiveUnnestRewriter` from a UNNEST relation will have + // only one expression, which is the placeholder column generated by the rewriter. + if self.dialect.unnest_as_table_factor() + && p.expr.len() == 1 + && Self::is_unnest_placeholder(&p.expr[0]) + { + if let LogicalPlan::Unnest(unnest) = &p.input.as_ref() { + return self + .unnest_to_table_factor_sql(unnest, query, select, relation); + } + } + // Projection can be top-level plan for derived table if select.already_projected() { return self.derive_with_dialect_alias( @@ -678,7 +693,11 @@ impl Unparser<'_> { ) } LogicalPlan::EmptyRelation(_) => { - relation.empty(); + // An EmptyRelation could be behind an UNNEST node. If the dialect supports UNNEST as a table factor, + // a TableRelationBuilder will be created for the UNNEST node first. + if !relation.has_relation() { + relation.empty(); + } Ok(()) } LogicalPlan::Extension(_) => not_impl_err!("Unsupported operator: {plan:?}"), @@ -708,6 +727,38 @@ impl Unparser<'_> { } } + /// Try to find the placeholder column name generated by `RecursiveUnnestRewriter` + /// Only match the pattern `Expr::Alias(Expr::Column("__unnest_placeholder(...)"))` + fn is_unnest_placeholder(expr: &Expr) -> bool { + if let Expr::Alias(Alias { expr, .. }) = expr { + if let Expr::Column(Column { name, .. }) = expr.as_ref() { + return name.starts_with(UNNEST_PLACEHOLDER); + } + } + false + } + + fn unnest_to_table_factor_sql( + &self, + unnest: &Unnest, + query: &mut Option, + select: &mut SelectBuilder, + relation: &mut RelationBuilder, + ) -> Result<()> { + let mut unnest_relation = UnnestRelationBuilder::default(); + let LogicalPlan::Projection(p) = unnest.input.as_ref() else { + return internal_err!("Unnest input is not a Projection: {unnest:?}"); + }; + let exprs = p + .expr + .iter() + .map(|e| self.expr_to_sql(e)) + .collect::>>()?; + unnest_relation.array_exprs(exprs); + relation.unnest(unnest_relation); + self.select_to_sql_recursively(p.input.as_ref(), query, select, relation) + } + fn is_scan_with_pushdown(scan: &TableScan) -> bool { scan.projection.is_some() || !scan.filters.is_empty() || scan.fetch.is_some() } diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 518781106c3b..354a68f60964 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -133,7 +133,7 @@ pub(crate) fn find_window_nodes_within_select<'a>( /// Recursively identify Column expressions and transform them into the appropriate unnest expression /// -/// For example, if expr contains the column expr "unnest_placeholder(make_array(Int64(1),Int64(2),Int64(2),Int64(5),NULL),depth=1)" +/// For example, if expr contains the column expr "__unnest_placeholder(make_array(Int64(1),Int64(2),Int64(2),Int64(5),NULL),depth=1)" /// it will be transformed into an actual unnest expression UNNEST([1, 2, 2, 5, NULL]) pub(crate) fn unproject_unnest_expr(expr: Expr, unnest: &Unnest) -> Result { expr.transform(|sub_expr| { diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 69e3953341ef..1c2a3ea91a2b 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -315,6 +315,8 @@ pub(crate) fn rewrite_recursive_unnests_bottom_up( .collect::>()) } +pub const UNNEST_PLACEHOLDER: &str = "__unnest_placeholder"; + /* This is only usedful when used with transform down up A full example of how the transformation works: @@ -360,9 +362,9 @@ impl RecursiveUnnestRewriter<'_> { // Full context, we are trying to plan the execution as InnerProjection->Unnest->OuterProjection // inside unnest execution, each column inside the inner projection // will be transformed into new columns. Thus we need to keep track of these placeholding column names - let placeholder_name = format!("unnest_placeholder({})", inner_expr_name); + let placeholder_name = format!("{UNNEST_PLACEHOLDER}({})", inner_expr_name); let post_unnest_name = - format!("unnest_placeholder({},depth={})", inner_expr_name, level); + format!("{UNNEST_PLACEHOLDER}({},depth={})", inner_expr_name, level); // This is due to the fact that unnest transformation should keep the original // column name as is, to comply with group by and order by let placeholder_column = Column::from_name(placeholder_name.clone()); @@ -693,17 +695,17 @@ mod tests { // Only the bottom most unnest exprs are transformed assert_eq!( transformed_exprs, - vec![col("unnest_placeholder(3d_col,depth=2)") + vec![col("__unnest_placeholder(3d_col,depth=2)") .alias("UNNEST(UNNEST(3d_col))") .add( - col("unnest_placeholder(3d_col,depth=2)") + col("__unnest_placeholder(3d_col,depth=2)") .alias("UNNEST(UNNEST(3d_col))") ) .add(col("i64_col"))] ); column_unnests_eq( vec![ - "unnest_placeholder(3d_col)=>[unnest_placeholder(3d_col,depth=2)|depth=2]", + "__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2]", ], &unnest_placeholder_columns, ); @@ -713,7 +715,7 @@ mod tests { assert_eq!( inner_projection_exprs, vec![ - col("3d_col").alias("unnest_placeholder(3d_col)"), + col("3d_col").alias("__unnest_placeholder(3d_col)"), col("i64_col") ] ); @@ -730,12 +732,12 @@ mod tests { assert_eq!( transformed_exprs, vec![ - (col("unnest_placeholder(3d_col,depth=1)").alias("UNNEST(3d_col)")) + (col("__unnest_placeholder(3d_col,depth=1)").alias("UNNEST(3d_col)")) .alias("2d_col") ] ); column_unnests_eq( - vec!["unnest_placeholder(3d_col)=>[unnest_placeholder(3d_col,depth=2)|depth=2, unnest_placeholder(3d_col,depth=1)|depth=1]"], + vec!["__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2, __unnest_placeholder(3d_col,depth=1)|depth=1]"], &unnest_placeholder_columns, ); // Still reference struct_col in original schema but with alias, @@ -743,7 +745,7 @@ mod tests { assert_eq!( inner_projection_exprs, vec![ - col("3d_col").alias("unnest_placeholder(3d_col)"), + col("3d_col").alias("__unnest_placeholder(3d_col)"), col("i64_col") ] ); @@ -794,19 +796,19 @@ mod tests { assert_eq!( transformed_exprs, vec![ - col("unnest_placeholder(struct_col).field1"), - col("unnest_placeholder(struct_col).field2"), + col("__unnest_placeholder(struct_col).field1"), + col("__unnest_placeholder(struct_col).field2"), ] ); column_unnests_eq( - vec!["unnest_placeholder(struct_col)"], + vec!["__unnest_placeholder(struct_col)"], &unnest_placeholder_columns, ); // Still reference struct_col in original schema but with alias, // to avoid colliding with the projection on the column itself if any assert_eq!( inner_projection_exprs, - vec![col("struct_col").alias("unnest_placeholder(struct_col)"),] + vec![col("struct_col").alias("__unnest_placeholder(struct_col)"),] ); // unnest(array_col) + 1 @@ -819,15 +821,15 @@ mod tests { )?; column_unnests_eq( vec![ - "unnest_placeholder(struct_col)", - "unnest_placeholder(array_col)=>[unnest_placeholder(array_col,depth=1)|depth=1]", + "__unnest_placeholder(struct_col)", + "__unnest_placeholder(array_col)=>[__unnest_placeholder(array_col,depth=1)|depth=1]", ], &unnest_placeholder_columns, ); // Only transform the unnest children assert_eq!( transformed_exprs, - vec![col("unnest_placeholder(array_col,depth=1)") + vec![col("__unnest_placeholder(array_col,depth=1)") .alias("UNNEST(array_col)") .add(lit(1i64))] ); @@ -838,8 +840,8 @@ mod tests { assert_eq!( inner_projection_exprs, vec![ - col("struct_col").alias("unnest_placeholder(struct_col)"), - col("array_col").alias("unnest_placeholder(array_col)") + col("struct_col").alias("__unnest_placeholder(struct_col)"), + col("array_col").alias("__unnest_placeholder(array_col)") ] ); @@ -907,7 +909,7 @@ mod tests { assert_eq!( transformed_exprs, vec![unnest( - col("unnest_placeholder(struct_list,depth=1)") + col("__unnest_placeholder(struct_list,depth=1)") .alias("UNNEST(struct_list)") .field("subfield1") )] @@ -915,14 +917,14 @@ mod tests { column_unnests_eq( vec![ - "unnest_placeholder(struct_list)=>[unnest_placeholder(struct_list,depth=1)|depth=1]", + "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]", ], &unnest_placeholder_columns, ); assert_eq!( inner_projection_exprs, - vec![col("struct_list").alias("unnest_placeholder(struct_list)")] + vec![col("struct_list").alias("__unnest_placeholder(struct_list)")] ); // continue rewrite another expr in select @@ -937,7 +939,7 @@ mod tests { assert_eq!( transformed_exprs, vec![unnest( - col("unnest_placeholder(struct_list,depth=1)") + col("__unnest_placeholder(struct_list,depth=1)") .alias("UNNEST(struct_list)") .field("subfield2") )] @@ -947,14 +949,14 @@ mod tests { // because expr1 and expr2 derive from the same unnest result column_unnests_eq( vec![ - "unnest_placeholder(struct_list)=>[unnest_placeholder(struct_list,depth=1)|depth=1]", + "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]", ], &unnest_placeholder_columns, ); assert_eq!( inner_projection_exprs, - vec![col("struct_list").alias("unnest_placeholder(struct_list)")] + vec![col("struct_list").alias("__unnest_placeholder(struct_list)")] ); Ok(()) diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index fcfee29f6ac9..236b59432a5f 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -525,6 +525,96 @@ fn roundtrip_statement_with_dialect() -> Result<()> { parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(SqliteDialect {}), }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3])", + expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))")"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", + expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS t1 (c1)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", + expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS t1 (c1)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]), j1", + expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") CROSS JOIN j1"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) JOIN j1 ON u.c1 = j1.j1_id", + expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS u (c1) JOIN j1 ON (u.c1 = j1.j1_id)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) UNION ALL SELECT * FROM UNNEST([4,5,6]) u(c1)", + expected: r#"SELECT * FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS u (c1) UNION ALL SELECT * FROM (SELECT UNNEST([4, 5, 6]) AS "UNNEST(make_array(Int64(4),Int64(5),Int64(6)))") AS u (c1)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3])", + expected: r#"SELECT * FROM UNNEST([1, 2, 3])"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", + expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS t1 (c1)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) AS t1 (c1)", + expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS t1 (c1)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]), j1", + expected: r#"SELECT * FROM UNNEST([1, 2, 3]) CROSS JOIN j1"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) JOIN j1 ON u.c1 = j1.j1_id", + expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS u (c1) JOIN j1 ON (u.c1 = j1.j1_id)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT * FROM UNNEST([1,2,3]) u(c1) UNION ALL SELECT * FROM UNNEST([4,5,6]) u(c1)", + expected: r#"SELECT * FROM UNNEST([1, 2, 3]) AS u (c1) UNION ALL SELECT * FROM UNNEST([4, 5, 6]) AS u (c1)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT UNNEST([1,2,3])", + expected: r#"SELECT * FROM UNNEST([1, 2, 3])"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT UNNEST([1,2,3]) as c1", + expected: r#"SELECT UNNEST([1, 2, 3]) AS c1"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT UNNEST([1,2,3]), 1", + expected: r#"SELECT UNNEST([1, 2, 3]) AS UNNEST(make_array(Int64(1),Int64(2),Int64(3))), Int64(1)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, ]; for query in tests { @@ -535,7 +625,8 @@ fn roundtrip_statement_with_dialect() -> Result<()> { let state = MockSessionState::default() .with_aggregate_function(max_udaf()) .with_aggregate_function(min_udaf()) - .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())) + .with_expr_planner(Arc::new(NestedFunctionPlanner)); let context = MockContextProvider { state }; let sql_to_rel = SqlToRel::new(&context); @@ -571,9 +662,9 @@ fn test_unnest_logical_plan() -> Result<()> { let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); let expected = r#" -Projection: unnest_placeholder(unnest_table.struct_col).field1, unnest_placeholder(unnest_table.struct_col).field2, unnest_placeholder(unnest_table.array_col,depth=1) AS UNNEST(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col - Unnest: lists[unnest_placeholder(unnest_table.array_col)|depth=1] structs[unnest_placeholder(unnest_table.struct_col)] - Projection: unnest_table.struct_col AS unnest_placeholder(unnest_table.struct_col), unnest_table.array_col AS unnest_placeholder(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col +Projection: __unnest_placeholder(unnest_table.struct_col).field1, __unnest_placeholder(unnest_table.struct_col).field2, __unnest_placeholder(unnest_table.array_col,depth=1) AS UNNEST(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col + Unnest: lists[__unnest_placeholder(unnest_table.array_col)|depth=1] structs[__unnest_placeholder(unnest_table.struct_col)] + Projection: unnest_table.struct_col AS __unnest_placeholder(unnest_table.struct_col), unnest_table.array_col AS __unnest_placeholder(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col TableScan: unnest_table"#.trim_start(); assert_eq!(plan.to_string(), expected); diff --git a/datafusion/sqllogictest/test_files/encoding.slt b/datafusion/sqllogictest/test_files/encoding.slt index fc22cc8bf7a7..24efb33f7896 100644 --- a/datafusion/sqllogictest/test_files/encoding.slt +++ b/datafusion/sqllogictest/test_files/encoding.slt @@ -101,4 +101,4 @@ FROM test_utf8view; Andrew QW5kcmV3 416e64726577 X WA 58 Xiangpeng WGlhbmdwZW5n 5869616e6770656e67 Xiangpeng WGlhbmdwZW5n 5869616e6770656e67 Raphael UmFwaGFlbA 5261706861656c R Ug 52 -NULL NULL NULL R Ug 52 \ No newline at end of file +NULL NULL NULL R Ug 52 diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 62f625119897..49aaa877caa6 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -4058,9 +4058,9 @@ logical_plan 03)----TableScan: join_t1 projection=[t1_id, t1_name] 04)--SubqueryAlias: series 05)----Subquery: -06)------Projection: unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int)),depth=1) AS i -07)--------Unnest: lists[unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int)))|depth=1] structs[] -08)----------Projection: generate_series(Int64(1), CAST(outer_ref(t1.t1_int) AS Int64)) AS unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int))) +06)------Projection: __unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int)),depth=1) AS i +07)--------Unnest: lists[__unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int)))|depth=1] structs[] +08)----------Projection: generate_series(Int64(1), CAST(outer_ref(t1.t1_int) AS Int64)) AS __unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int))) 09)------------EmptyRelation physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(UInt32, Column { relation: Some(Bare { table: "t1" }), name: "t1_int" }) @@ -4081,9 +4081,9 @@ logical_plan 03)----TableScan: join_t1 projection=[t1_id, t1_name] 04)--SubqueryAlias: series 05)----Subquery: -06)------Projection: unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int)),depth=1) AS i -07)--------Unnest: lists[unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int)))|depth=1] structs[] -08)----------Projection: generate_series(Int64(1), CAST(outer_ref(t2.t1_int) AS Int64)) AS unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int))) +06)------Projection: __unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int)),depth=1) AS i +07)--------Unnest: lists[__unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int)))|depth=1] structs[] +08)----------Projection: generate_series(Int64(1), CAST(outer_ref(t2.t1_int) AS Int64)) AS __unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int))) 09)------------EmptyRelation physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(UInt32, Column { relation: Some(Bare { table: "t2" }), name: "t1_int" }) diff --git a/datafusion/sqllogictest/test_files/push_down_filter.slt b/datafusion/sqllogictest/test_files/push_down_filter.slt index 86aa07b04ce1..64cc51b3c4ff 100644 --- a/datafusion/sqllogictest/test_files/push_down_filter.slt +++ b/datafusion/sqllogictest/test_files/push_down_filter.slt @@ -36,9 +36,9 @@ query TT explain select uc2 from (select unnest(column2) as uc2, column1 from v) where column1 = 2; ---- logical_plan -01)Projection: unnest_placeholder(v.column2,depth=1) AS uc2 -02)--Unnest: lists[unnest_placeholder(v.column2)|depth=1] structs[] -03)----Projection: v.column2 AS unnest_placeholder(v.column2), v.column1 +01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2 +02)--Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[] +03)----Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1 04)------Filter: v.column1 = Int64(2) 05)--------TableScan: v projection=[column1, column2] @@ -53,11 +53,11 @@ query TT explain select uc2 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3; ---- logical_plan -01)Projection: unnest_placeholder(v.column2,depth=1) AS uc2 -02)--Filter: unnest_placeholder(v.column2,depth=1) > Int64(3) -03)----Projection: unnest_placeholder(v.column2,depth=1) -04)------Unnest: lists[unnest_placeholder(v.column2)|depth=1] structs[] -05)--------Projection: v.column2 AS unnest_placeholder(v.column2), v.column1 +01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2 +02)--Filter: __unnest_placeholder(v.column2,depth=1) > Int64(3) +03)----Projection: __unnest_placeholder(v.column2,depth=1) +04)------Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[] +05)--------Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1 06)----------TableScan: v projection=[column1, column2] query II @@ -71,10 +71,10 @@ query TT explain select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 AND column1 = 2; ---- logical_plan -01)Projection: unnest_placeholder(v.column2,depth=1) AS uc2, v.column1 -02)--Filter: unnest_placeholder(v.column2,depth=1) > Int64(3) -03)----Unnest: lists[unnest_placeholder(v.column2)|depth=1] structs[] -04)------Projection: v.column2 AS unnest_placeholder(v.column2), v.column1 +01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2, v.column1 +02)--Filter: __unnest_placeholder(v.column2,depth=1) > Int64(3) +03)----Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[] +04)------Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1 05)--------Filter: v.column1 = Int64(2) 06)----------TableScan: v projection=[column1, column2] @@ -90,10 +90,10 @@ query TT explain select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 OR column1 = 2; ---- logical_plan -01)Projection: unnest_placeholder(v.column2,depth=1) AS uc2, v.column1 -02)--Filter: unnest_placeholder(v.column2,depth=1) > Int64(3) OR v.column1 = Int64(2) -03)----Unnest: lists[unnest_placeholder(v.column2)|depth=1] structs[] -04)------Projection: v.column2 AS unnest_placeholder(v.column2), v.column1 +01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2, v.column1 +02)--Filter: __unnest_placeholder(v.column2,depth=1) > Int64(3) OR v.column1 = Int64(2) +03)----Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[] +04)------Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1 05)--------TableScan: v projection=[column1, column2] statement ok @@ -112,10 +112,10 @@ query TT explain select * from (select column1, unnest(column2) as o from d) where o['a'] = 1; ---- logical_plan -01)Projection: d.column1, unnest_placeholder(d.column2,depth=1) AS o -02)--Filter: get_field(unnest_placeholder(d.column2,depth=1), Utf8("a")) = Int64(1) -03)----Unnest: lists[unnest_placeholder(d.column2)|depth=1] structs[] -04)------Projection: d.column1, d.column2 AS unnest_placeholder(d.column2) +01)Projection: d.column1, __unnest_placeholder(d.column2,depth=1) AS o +02)--Filter: get_field(__unnest_placeholder(d.column2,depth=1), Utf8("a")) = Int64(1) +03)----Unnest: lists[__unnest_placeholder(d.column2)|depth=1] structs[] +04)------Projection: d.column1, d.column2 AS __unnest_placeholder(d.column2) 05)--------TableScan: d projection=[column1, column2] diff --git a/datafusion/sqllogictest/test_files/table_functions.slt b/datafusion/sqllogictest/test_files/table_functions.slt index 12402e0d70c5..79294993dded 100644 --- a/datafusion/sqllogictest/test_files/table_functions.slt +++ b/datafusion/sqllogictest/test_files/table_functions.slt @@ -139,4 +139,4 @@ SELECT generate_series(1, t1.end) FROM generate_series(3, 5) as t1(end) ---- [1, 2, 3, 4, 5] [1, 2, 3, 4] -[1, 2, 3] \ No newline at end of file +[1, 2, 3] diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt index d409e0902f7e..1c54006bd2a0 100644 --- a/datafusion/sqllogictest/test_files/unnest.slt +++ b/datafusion/sqllogictest/test_files/unnest.slt @@ -594,17 +594,17 @@ query TT explain select unnest(unnest(column3)), column3 from recursive_unnest_table; ---- logical_plan -01)Unnest: lists[] structs[unnest_placeholder(UNNEST(recursive_unnest_table.column3))] -02)--Projection: unnest_placeholder(recursive_unnest_table.column3,depth=1) AS UNNEST(recursive_unnest_table.column3) AS unnest_placeholder(UNNEST(recursive_unnest_table.column3)), recursive_unnest_table.column3 -03)----Unnest: lists[unnest_placeholder(recursive_unnest_table.column3)|depth=1] structs[] -04)------Projection: recursive_unnest_table.column3 AS unnest_placeholder(recursive_unnest_table.column3), recursive_unnest_table.column3 +01)Unnest: lists[] structs[__unnest_placeholder(UNNEST(recursive_unnest_table.column3))] +02)--Projection: __unnest_placeholder(recursive_unnest_table.column3,depth=1) AS UNNEST(recursive_unnest_table.column3) AS __unnest_placeholder(UNNEST(recursive_unnest_table.column3)), recursive_unnest_table.column3 +03)----Unnest: lists[__unnest_placeholder(recursive_unnest_table.column3)|depth=1] structs[] +04)------Projection: recursive_unnest_table.column3 AS __unnest_placeholder(recursive_unnest_table.column3), recursive_unnest_table.column3 05)--------TableScan: recursive_unnest_table projection=[column3] physical_plan 01)UnnestExec 02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -03)----ProjectionExec: expr=[unnest_placeholder(recursive_unnest_table.column3,depth=1)@0 as unnest_placeholder(UNNEST(recursive_unnest_table.column3)), column3@1 as column3] +03)----ProjectionExec: expr=[__unnest_placeholder(recursive_unnest_table.column3,depth=1)@0 as __unnest_placeholder(UNNEST(recursive_unnest_table.column3)), column3@1 as column3] 04)------UnnestExec -05)--------ProjectionExec: expr=[column3@0 as unnest_placeholder(recursive_unnest_table.column3), column3@0 as column3] +05)--------ProjectionExec: expr=[column3@0 as __unnest_placeholder(recursive_unnest_table.column3), column3@0 as column3] 06)----------MemoryExec: partitions=1, partition_sizes=[1] ## unnest->field_access->unnest->unnest @@ -650,19 +650,19 @@ query TT explain select unnest(unnest(unnest(column3)['c1'])), column3 from recursive_unnest_table; ---- logical_plan -01)Projection: unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1],depth=2) AS UNNEST(UNNEST(UNNEST(recursive_unnest_table.column3)[c1])), recursive_unnest_table.column3 -02)--Unnest: lists[unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1])|depth=2] structs[] -03)----Projection: get_field(unnest_placeholder(recursive_unnest_table.column3,depth=1) AS UNNEST(recursive_unnest_table.column3), Utf8("c1")) AS unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), recursive_unnest_table.column3 -04)------Unnest: lists[unnest_placeholder(recursive_unnest_table.column3)|depth=1] structs[] -05)--------Projection: recursive_unnest_table.column3 AS unnest_placeholder(recursive_unnest_table.column3), recursive_unnest_table.column3 +01)Projection: __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1],depth=2) AS UNNEST(UNNEST(UNNEST(recursive_unnest_table.column3)[c1])), recursive_unnest_table.column3 +02)--Unnest: lists[__unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1])|depth=2] structs[] +03)----Projection: get_field(__unnest_placeholder(recursive_unnest_table.column3,depth=1) AS UNNEST(recursive_unnest_table.column3), Utf8("c1")) AS __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), recursive_unnest_table.column3 +04)------Unnest: lists[__unnest_placeholder(recursive_unnest_table.column3)|depth=1] structs[] +05)--------Projection: recursive_unnest_table.column3 AS __unnest_placeholder(recursive_unnest_table.column3), recursive_unnest_table.column3 06)----------TableScan: recursive_unnest_table projection=[column3] physical_plan -01)ProjectionExec: expr=[unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1],depth=2)@0 as UNNEST(UNNEST(UNNEST(recursive_unnest_table.column3)[c1])), column3@1 as column3] +01)ProjectionExec: expr=[__unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1],depth=2)@0 as UNNEST(UNNEST(UNNEST(recursive_unnest_table.column3)[c1])), column3@1 as column3] 02)--UnnestExec -03)----ProjectionExec: expr=[get_field(unnest_placeholder(recursive_unnest_table.column3,depth=1)@0, c1) as unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), column3@1 as column3] +03)----ProjectionExec: expr=[get_field(__unnest_placeholder(recursive_unnest_table.column3,depth=1)@0, c1) as __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), column3@1 as column3] 04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 05)--------UnnestExec -06)----------ProjectionExec: expr=[column3@0 as unnest_placeholder(recursive_unnest_table.column3), column3@0 as column3] +06)----------ProjectionExec: expr=[column3@0 as __unnest_placeholder(recursive_unnest_table.column3), column3@0 as column3] 07)------------MemoryExec: partitions=1, partition_sizes=[1] diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 6c48ac68ab6b..188e2ae0915f 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -5127,4 +5127,3 @@ order by id; statement ok drop table t1; - From c697bb06a6eb7360a3bf4d2184da8a2bc0567a3d Mon Sep 17 00:00:00 2001 From: zhuliquan Date: Mon, 16 Dec 2024 20:36:45 +0800 Subject: [PATCH 26/26] fix: take taplo formatter suggestion --- datafusion/core/Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 12a09e2045b5..33a11a20d306 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -215,4 +215,3 @@ required-features = ["nested_expressions"] [[bench]] harness = false name = "scalar_regex_match_query_sql" -