From f0bf01691f3f496e6b37c8994d5708273bbdb125 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 11 Jun 2023 06:57:14 -0400 Subject: [PATCH 1/7] Minor: Move `PlanType`, `StringifiedPlan` and `ToStringifiedPlan` `datafusion_common` (#6571) * Move DisplayablePlan to `datafusion_common` * Update uses * Update datafusion/common/src/display.rs Co-authored-by: Liang-Chi Hsieh --------- Co-authored-by: Liang-Chi Hsieh --- datafusion/common/src/display.rs | 110 +++++++++++++++++++ datafusion/common/src/lib.rs | 1 + datafusion/core/src/physical_plan/display.rs | 2 +- datafusion/core/src/physical_plan/explain.rs | 7 +- datafusion/core/src/physical_plan/planner.rs | 4 +- datafusion/expr/src/logical_plan/builder.rs | 8 +- datafusion/expr/src/logical_plan/plan.rs | 90 +-------------- 7 files changed, 124 insertions(+), 98 deletions(-) create mode 100644 datafusion/common/src/display.rs diff --git a/datafusion/common/src/display.rs b/datafusion/common/src/display.rs new file mode 100644 index 000000000000..79de9bc031d6 --- /dev/null +++ b/datafusion/common/src/display.rs @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Types for plan display + +use std::{ + fmt::{self, Display, Formatter}, + sync::Arc, +}; + +/// Represents which type of plan, when storing multiple +/// for use in EXPLAIN plans +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum PlanType { + /// The initial LogicalPlan provided to DataFusion + InitialLogicalPlan, + /// The LogicalPlan which results from applying an analyzer pass + AnalyzedLogicalPlan { + /// The name of the analyzer which produced this plan + analyzer_name: String, + }, + /// The LogicalPlan after all analyzer passes have been applied + FinalAnalyzedLogicalPlan, + /// The LogicalPlan which results from applying an optimizer pass + OptimizedLogicalPlan { + /// The name of the optimizer which produced this plan + optimizer_name: String, + }, + /// The final, fully optimized LogicalPlan that was converted to a physical plan + FinalLogicalPlan, + /// The initial physical plan, prepared for execution + InitialPhysicalPlan, + /// The ExecutionPlan which results from applying an optimizer pass + OptimizedPhysicalPlan { + /// The name of the optimizer which produced this plan + optimizer_name: String, + }, + /// The final, fully optimized physical which would be executed + FinalPhysicalPlan, +} + +impl Display for PlanType { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match self { + PlanType::InitialLogicalPlan => write!(f, "initial_logical_plan"), + PlanType::AnalyzedLogicalPlan { analyzer_name } => { + write!(f, "logical_plan after {analyzer_name}") + } + PlanType::FinalAnalyzedLogicalPlan => write!(f, "analyzed_logical_plan"), + PlanType::OptimizedLogicalPlan { optimizer_name } => { + write!(f, "logical_plan after {optimizer_name}") + } + PlanType::FinalLogicalPlan => write!(f, "logical_plan"), + PlanType::InitialPhysicalPlan => write!(f, "initial_physical_plan"), + PlanType::OptimizedPhysicalPlan { optimizer_name } => { + write!(f, "physical_plan after {optimizer_name}") + } + PlanType::FinalPhysicalPlan => write!(f, "physical_plan"), + } + } +} + +/// Represents some sort of execution plan, in String form +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct StringifiedPlan { + /// An identifier of what type of plan this string represents + pub plan_type: PlanType, + /// The string representation of the plan + pub plan: Arc, +} + +impl StringifiedPlan { + /// Create a new Stringified plan of `plan_type` with string + /// representation `plan` + pub fn new(plan_type: PlanType, plan: impl Into) -> Self { + StringifiedPlan { + plan_type, + plan: Arc::new(plan.into()), + } + } + + /// Returns true if this plan should be displayed. Generally + /// `verbose_mode = true` will display all available plans + pub fn should_display(&self, verbose_mode: bool) -> bool { + match self.plan_type { + PlanType::FinalLogicalPlan | PlanType::FinalPhysicalPlan => true, + _ => verbose_mode, + } + } +} + +/// Trait for something that can be formatted as a stringified plan +pub trait ToStringifiedPlan { + /// Create a stringified plan with the specified type + fn to_stringified(&self, plan_type: PlanType) -> StringifiedPlan; +} diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index ef7e0947008a..80b8a8597757 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -20,6 +20,7 @@ mod column; pub mod config; pub mod delta; mod dfschema; +pub mod display; mod error; mod join_type; pub mod parsers; diff --git a/datafusion/core/src/physical_plan/display.rs b/datafusion/core/src/physical_plan/display.rs index 5f286eed185c..2fba06ed29c1 100644 --- a/datafusion/core/src/physical_plan/display.rs +++ b/datafusion/core/src/physical_plan/display.rs @@ -21,7 +21,7 @@ use std::fmt; -use crate::logical_expr::{StringifiedPlan, ToStringifiedPlan}; +use datafusion_common::display::{StringifiedPlan, ToStringifiedPlan}; use super::{accept, ExecutionPlan, ExecutionPlanVisitor}; diff --git a/datafusion/core/src/physical_plan/explain.rs b/datafusion/core/src/physical_plan/explain.rs index fc70626d9ba0..e40512f1b142 100644 --- a/datafusion/core/src/physical_plan/explain.rs +++ b/datafusion/core/src/physical_plan/explain.rs @@ -20,12 +20,11 @@ use std::any::Any; use std::sync::Arc; +use datafusion_common::display::StringifiedPlan; + use datafusion_common::{DataFusionError, Result}; -use crate::{ - logical_expr::StringifiedPlan, - physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning, Statistics}, -}; +use crate::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning, Statistics}; use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch}; use log::trace; diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 6f45b7b5452d..38e125dc9338 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -32,8 +32,10 @@ use crate::logical_expr::{ }; use crate::logical_expr::{ CrossJoin, Expr, LogicalPlan, Partitioning as LogicalPartitioning, PlanType, - Repartition, ToStringifiedPlan, Union, UserDefinedLogicalNode, + Repartition, Union, UserDefinedLogicalNode, }; +use datafusion_common::display::ToStringifiedPlan; + use crate::logical_expr::{Limit, Values}; use crate::physical_expr::create_physical_expr; use crate::physical_optimizer::optimizer::PhysicalOptimizerRule; diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 606b990cfe9e..99489637da23 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -29,8 +29,8 @@ use crate::{ logical_plan::{ Aggregate, Analyze, CrossJoin, Distinct, EmptyRelation, Explain, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, - Projection, Repartition, Sort, SubqueryAlias, TableScan, ToStringifiedPlan, - Union, Unnest, Values, Window, + Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values, + Window, }, utils::{ can_hash, expand_qualified_wildcard, expand_wildcard, @@ -40,8 +40,8 @@ use crate::{ }; use arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion_common::{ - Column, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, Result, - ScalarValue, TableReference, ToDFSchema, + display::ToStringifiedPlan, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, + OwnedTableReference, Result, ScalarValue, TableReference, ToDFSchema, }; use std::any::Any; use std::cmp::Ordering; diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index e19b327785a2..deaea25145e1 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -42,7 +42,8 @@ use std::fmt::{self, Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; -// backwards compatible +// backwards compatibility +pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; use super::DdlStatement; @@ -1650,93 +1651,6 @@ pub enum Partitioning { DistributeBy(Vec), } -/// Represents which type of plan, when storing multiple -/// for use in EXPLAIN plans -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum PlanType { - /// The initial LogicalPlan provided to DataFusion - InitialLogicalPlan, - /// The LogicalPlan which results from applying an analyzer pass - AnalyzedLogicalPlan { - /// The name of the analyzer which produced this plan - analyzer_name: String, - }, - /// The LogicalPlan after all analyzer passes have been applied - FinalAnalyzedLogicalPlan, - /// The LogicalPlan which results from applying an optimizer pass - OptimizedLogicalPlan { - /// The name of the optimizer which produced this plan - optimizer_name: String, - }, - /// The final, fully optimized LogicalPlan that was converted to a physical plan - FinalLogicalPlan, - /// The initial physical plan, prepared for execution - InitialPhysicalPlan, - /// The ExecutionPlan which results from applying an optimizer pass - OptimizedPhysicalPlan { - /// The name of the optimizer which produced this plan - optimizer_name: String, - }, - /// The final, fully optimized physical which would be executed - FinalPhysicalPlan, -} - -impl Display for PlanType { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - match self { - PlanType::InitialLogicalPlan => write!(f, "initial_logical_plan"), - PlanType::AnalyzedLogicalPlan { analyzer_name } => { - write!(f, "logical_plan after {analyzer_name}") - } - PlanType::FinalAnalyzedLogicalPlan => write!(f, "analyzed_logical_plan"), - PlanType::OptimizedLogicalPlan { optimizer_name } => { - write!(f, "logical_plan after {optimizer_name}") - } - PlanType::FinalLogicalPlan => write!(f, "logical_plan"), - PlanType::InitialPhysicalPlan => write!(f, "initial_physical_plan"), - PlanType::OptimizedPhysicalPlan { optimizer_name } => { - write!(f, "physical_plan after {optimizer_name}") - } - PlanType::FinalPhysicalPlan => write!(f, "physical_plan"), - } - } -} - -/// Represents some sort of execution plan, in String form -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct StringifiedPlan { - /// An identifier of what type of plan this string represents - pub plan_type: PlanType, - /// The string representation of the plan - pub plan: Arc, -} - -impl StringifiedPlan { - /// Create a new Stringified plan of `plan_type` with string - /// representation `plan` - pub fn new(plan_type: PlanType, plan: impl Into) -> Self { - StringifiedPlan { - plan_type, - plan: Arc::new(plan.into()), - } - } - - /// returns true if this plan should be displayed. Generally - /// `verbose_mode = true` will display all available plans - pub fn should_display(&self, verbose_mode: bool) -> bool { - match self.plan_type { - PlanType::FinalLogicalPlan | PlanType::FinalPhysicalPlan => true, - _ => verbose_mode, - } - } -} - -/// Trait for something that can be formatted as a stringified plan -pub trait ToStringifiedPlan { - /// Create a stringified plan with the specified type - fn to_stringified(&self, plan_type: PlanType) -> StringifiedPlan; -} - /// Unnest a column that contains a nested list type. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Unnest { From c166a584201beaf3af43149d9ee515a5a40e3be3 Mon Sep 17 00:00:00 2001 From: jakevin Date: Sun, 11 Jun 2023 19:03:24 +0800 Subject: [PATCH 2/7] fix: correct test timestamp_add_interval_months (#6622) --- datafusion/core/tests/sql/timestamp.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/datafusion/core/tests/sql/timestamp.rs b/datafusion/core/tests/sql/timestamp.rs index 2058d8ed1fd6..df922844bbca 100644 --- a/datafusion/core/tests/sql/timestamp.rs +++ b/datafusion/core/tests/sql/timestamp.rs @@ -565,7 +565,6 @@ async fn timestamp_sub_interval_days() -> Result<()> { } #[tokio::test] -#[ignore] // https://github.com/apache/arrow-datafusion/issues/3327 async fn timestamp_add_interval_months() -> Result<()> { let ctx = SessionContext::new(); @@ -576,7 +575,7 @@ async fn timestamp_add_interval_months() -> Result<()> { let res1 = actual[0][0].as_str(); let res2 = actual[0][1].as_str(); - let format = "%Y-%m-%d %H:%M:%S%.6f"; + let format = "%Y-%m-%dT%H:%M:%S%.6fZ"; let t1_naive = chrono::NaiveDateTime::parse_from_str(res1, format).unwrap(); let t2_naive = chrono::NaiveDateTime::parse_from_str(res2, format).unwrap(); From c4a036dabc09c6f00b533468c62223b21f564cdd Mon Sep 17 00:00:00 2001 From: Folyd Date: Sun, 11 Jun 2023 19:05:08 +0800 Subject: [PATCH 3/7] Impl `Literal` trait for `NonZero*` types (#6627) --- datafusion/expr/src/literal.rs | 61 ++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/datafusion/expr/src/literal.rs b/datafusion/expr/src/literal.rs index dc7412b5946c..effc31553819 100644 --- a/datafusion/expr/src/literal.rs +++ b/datafusion/expr/src/literal.rs @@ -88,6 +88,17 @@ macro_rules! make_literal { }; } +macro_rules! make_nonzero_literal { + ($TYPE:ty, $SCALAR:ident, $DOC: expr) => { + #[doc = $DOC] + impl Literal for $TYPE { + fn lit(&self) -> Expr { + Expr::Literal(ScalarValue::$SCALAR(Some(self.get()))) + } + } + }; +} + macro_rules! make_timestamp_literal { ($TYPE:ty, $SCALAR:ident, $DOC: expr) => { #[doc = $DOC] @@ -114,6 +125,47 @@ make_literal!(u16, UInt16, "literal expression containing a u16"); make_literal!(u32, UInt32, "literal expression containing a u32"); make_literal!(u64, UInt64, "literal expression containing a u64"); +make_nonzero_literal!( + std::num::NonZeroI8, + Int8, + "literal expression containing an i8" +); +make_nonzero_literal!( + std::num::NonZeroI16, + Int16, + "literal expression containing an i16" +); +make_nonzero_literal!( + std::num::NonZeroI32, + Int32, + "literal expression containing an i32" +); +make_nonzero_literal!( + std::num::NonZeroI64, + Int64, + "literal expression containing an i64" +); +make_nonzero_literal!( + std::num::NonZeroU8, + UInt8, + "literal expression containing a u8" +); +make_nonzero_literal!( + std::num::NonZeroU16, + UInt16, + "literal expression containing a u16" +); +make_nonzero_literal!( + std::num::NonZeroU32, + UInt32, + "literal expression containing a u32" +); +make_nonzero_literal!( + std::num::NonZeroU64, + UInt64, + "literal expression containing a u64" +); + make_timestamp_literal!(i8, Int8, "literal expression containing an i8"); make_timestamp_literal!(i16, Int16, "literal expression containing an i16"); make_timestamp_literal!(i32, Int32, "literal expression containing an i32"); @@ -124,10 +176,19 @@ make_timestamp_literal!(u32, UInt32, "literal expression containing a u32"); #[cfg(test)] mod test { + use std::num::NonZeroU32; + use super::*; use crate::expr_fn::col; use datafusion_common::ScalarValue; + #[test] + fn test_lit_nonzero() { + let expr = col("id").eq(lit(NonZeroU32::new(1).unwrap())); + let expected = col("id").eq(lit(ScalarValue::UInt32(Some(1)))); + assert_eq!(expr, expected); + } + #[test] fn test_lit_timestamp_nano() { let expr = col("time").eq(lit_timestamp_nano(10)); // 10 is an implicit i32 From 57bc5b02990a66aed8bf142435e29c69f733e8b8 Mon Sep 17 00:00:00 2001 From: jakevin Date: Sun, 11 Jun 2023 20:24:42 +0800 Subject: [PATCH 4/7] style: make clippy happy and remove redundant prefix (#6624) * clippy * remove default() * style: remove redundant prefix in function.rs --- datafusion-cli/src/exec.rs | 4 +- datafusion-cli/src/object_storage.rs | 2 +- .../src/datasource/file_format/options.rs | 4 +- .../core/src/datasource/listing/table.rs | 4 +- .../src/datasource/listing_table_factory.rs | 4 +- datafusion/core/src/physical_plan/analyze.rs | 2 +- datafusion/expr/src/expr.rs | 6 +- datafusion/expr/src/function.rs | 454 ++++++------------ datafusion/expr/src/function_err.rs | 2 +- .../src/rewrite_disjunctive_predicate.rs | 2 +- .../physical-expr/src/expressions/binary.rs | 3 +- datafusion/proto/src/logical_plan/mod.rs | 2 +- datafusion/sql/src/parser.rs | 3 +- 13 files changed, 178 insertions(+), 314 deletions(-) diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index cec0fe03739a..0debe240db01 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -244,13 +244,13 @@ mod tests { async fn create_external_table_test(location: &str, sql: &str) -> Result<()> { let ctx = SessionContext::new(); - let plan = ctx.state().create_logical_plan(&sql).await?; + let plan = ctx.state().create_logical_plan(sql).await?; match &plan { LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) => { create_external_table(&ctx, cmd).await?; } - _ => assert!(false), + _ => unreachable!(), }; ctx.runtime_env() diff --git a/datafusion-cli/src/object_storage.rs b/datafusion-cli/src/object_storage.rs index 46b03a0a36a2..e4b7033c34f2 100644 --- a/datafusion-cli/src/object_storage.rs +++ b/datafusion-cli/src/object_storage.rs @@ -57,7 +57,7 @@ pub async fn get_s3_object_store_builder( .ok_or_else(|| { DataFusionError::ObjectStore(object_store::Error::Generic { store: "S3", - source: format!("Failed to get S3 credentials from environment") + source: "Failed to get S3 credentials from environment".to_string() .into(), }) })? diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index 3e802362d3ae..5694bf5380d5 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -512,7 +512,7 @@ impl ReadOptions<'_> for NdJsonReadOptions<'_> { #[async_trait] impl ReadOptions<'_> for AvroReadOptions<'_> { fn to_listing_options(&self, config: &SessionConfig) -> ListingOptions { - let file_format = AvroFormat::default(); + let file_format = AvroFormat; ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) @@ -535,7 +535,7 @@ impl ReadOptions<'_> for AvroReadOptions<'_> { #[async_trait] impl ReadOptions<'_> for ArrowReadOptions<'_> { fn to_listing_options(&self, config: &SessionConfig) -> ListingOptions { - let file_format = ArrowFormat::default(); + let file_format = ArrowFormat; ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 0252e99ab8a5..f6b0183f21e6 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -133,8 +133,8 @@ impl ListingTableConfig { .map_err(|_| DataFusionError::Internal(err_msg))?; let file_format: Arc = match file_type { - FileType::ARROW => Arc::new(ArrowFormat::default()), - FileType::AVRO => Arc::new(AvroFormat::default()), + FileType::ARROW => Arc::new(ArrowFormat), + FileType::AVRO => Arc::new(AvroFormat), FileType::CSV => Arc::new( CsvFormat::default().with_file_compression_type(file_compression_type), ), diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 7d10fc8e0e89..4bc6c124150b 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -78,11 +78,11 @@ impl TableProviderFactory for ListingTableFactory { .with_file_compression_type(file_compression_type), ), FileType::PARQUET => Arc::new(ParquetFormat::default()), - FileType::AVRO => Arc::new(AvroFormat::default()), + FileType::AVRO => Arc::new(AvroFormat), FileType::JSON => Arc::new( JsonFormat::default().with_file_compression_type(file_compression_type), ), - FileType::ARROW => Arc::new(ArrowFormat::default()), + FileType::ARROW => Arc::new(ArrowFormat), }; let (provided_schema, table_partition_cols) = if cmd.schema.fields().is_empty() { diff --git a/datafusion/core/src/physical_plan/analyze.rs b/datafusion/core/src/physical_plan/analyze.rs index 3923033d2e6e..2e4441e307ab 100644 --- a/datafusion/core/src/physical_plan/analyze.rs +++ b/datafusion/core/src/physical_plan/analyze.rs @@ -209,7 +209,7 @@ fn create_output_batch( plan_builder.append_value(total_rows.to_string()); type_builder.append_value("Duration"); - plan_builder.append_value(format!("{:?}", duration)); + plan_builder.append_value(format!("{duration:?}")); } RecordBatch::try_new( diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 86480f9a96b5..6ac2404f9dee 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -992,7 +992,7 @@ impl fmt::Debug for Expr { write!(f, " FILTER (WHERE {fe})")?; } if let Some(ob) = order_by { - write!(f, " ORDER BY {:?}", ob)?; + write!(f, " ORDER BY {ob:?}")?; } Ok(()) } @@ -1008,7 +1008,7 @@ impl fmt::Debug for Expr { write!(f, " FILTER (WHERE {fe})")?; } if let Some(ob) = order_by { - write!(f, " ORDER BY {:?}", ob)?; + write!(f, " ORDER BY {ob:?}")?; } Ok(()) } @@ -1374,7 +1374,7 @@ fn create_name(e: &Expr) -> Result { info += &format!(" FILTER (WHERE {fe})"); } if let Some(ob) = order_by { - info += &format!(" ORDER BY ({:?})", ob); + info += &format!(" ORDER BY ({ob:?})"); } Ok(format!("{}({}){}", fun.name, names.join(","), info)) } diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index bec672ab6f6c..f47b94322d94 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -97,6 +97,9 @@ pub fn return_type( fun: &BuiltinScalarFunction, input_expr_types: &[DataType], ) -> Result { + use DataType::*; + use TimeUnit::*; + // Note that this function *must* return the same type that the respective physical expression returns // or the execution panics. @@ -116,7 +119,7 @@ pub fn return_type( // Some built-in functions' return type depends on the incoming type. match fun { BuiltinScalarFunction::ArrayAppend => match &input_expr_types[0] { - DataType::List(field) => Ok(DataType::List(Arc::new(Field::new( + List(field) => Ok(List(Arc::new(Field::new( "item", field.data_type().clone(), true, @@ -126,7 +129,7 @@ pub fn return_type( ))), }, BuiltinScalarFunction::ArrayConcat => match &input_expr_types[0] { - DataType::List(field) => Ok(DataType::List(Arc::new(Field::new( + List(field) => Ok(List(Arc::new(Field::new( "item", field.data_type().clone(), true, @@ -135,18 +138,18 @@ pub fn return_type( "The {fun} function can only accept fixed size list as the args." ))), }, - BuiltinScalarFunction::ArrayDims => Ok(DataType::UInt8), - BuiltinScalarFunction::ArrayFill => Ok(DataType::List(Arc::new(Field::new( + BuiltinScalarFunction::ArrayDims => Ok(UInt8), + BuiltinScalarFunction::ArrayFill => Ok(List(Arc::new(Field::new( "item", input_expr_types[0].clone(), true, )))), - BuiltinScalarFunction::ArrayLength => Ok(DataType::UInt8), - BuiltinScalarFunction::ArrayNdims => Ok(DataType::UInt8), - BuiltinScalarFunction::ArrayPosition => Ok(DataType::UInt8), - BuiltinScalarFunction::ArrayPositions => Ok(DataType::UInt8), + BuiltinScalarFunction::ArrayLength => Ok(UInt8), + BuiltinScalarFunction::ArrayNdims => Ok(UInt8), + BuiltinScalarFunction::ArrayPosition => Ok(UInt8), + BuiltinScalarFunction::ArrayPositions => Ok(UInt8), BuiltinScalarFunction::ArrayPrepend => match &input_expr_types[1] { - DataType::List(field) => Ok(DataType::List(Arc::new(Field::new( + List(field) => Ok(List(Arc::new(Field::new( "item", field.data_type().clone(), true, @@ -156,7 +159,7 @@ pub fn return_type( ))), }, BuiltinScalarFunction::ArrayRemove => match &input_expr_types[0] { - DataType::List(field) => Ok(DataType::List(Arc::new(Field::new( + List(field) => Ok(List(Arc::new(Field::new( "item", field.data_type().clone(), true, @@ -166,7 +169,7 @@ pub fn return_type( ))), }, BuiltinScalarFunction::ArrayReplace => match &input_expr_types[0] { - DataType::List(field) => Ok(DataType::List(Arc::new(Field::new( + List(field) => Ok(List(Arc::new(Field::new( "item", field.data_type().clone(), true, @@ -176,7 +179,7 @@ pub fn return_type( ))), }, BuiltinScalarFunction::ArrayToString => match &input_expr_types[0] { - DataType::List(field) => Ok(DataType::List(Arc::new(Field::new( + List(field) => Ok(List(Arc::new(Field::new( "item", field.data_type().clone(), true, @@ -185,14 +188,14 @@ pub fn return_type( "The {fun} function can only accept list as the first argument" ))), }, - BuiltinScalarFunction::Cardinality => Ok(DataType::UInt64), - BuiltinScalarFunction::MakeArray => Ok(DataType::List(Arc::new(Field::new( + BuiltinScalarFunction::Cardinality => Ok(UInt64), + BuiltinScalarFunction::MakeArray => Ok(List(Arc::new(Field::new( "item", input_expr_types[0].clone(), true, )))), BuiltinScalarFunction::TrimArray => match &input_expr_types[0] { - DataType::List(field) => Ok(DataType::List(Arc::new(Field::new( + List(field) => Ok(List(Arc::new(Field::new( "item", field.data_type().clone(), true, @@ -201,7 +204,7 @@ pub fn return_type( "The {fun} function can only accept list as the first argument" ))), }, - BuiltinScalarFunction::Ascii => Ok(DataType::Int32), + BuiltinScalarFunction::Ascii => Ok(Int32), BuiltinScalarFunction::BitLength => { utf8_to_int_type(&input_expr_types[0], "bit_length") } @@ -209,29 +212,21 @@ pub fn return_type( BuiltinScalarFunction::CharacterLength => { utf8_to_int_type(&input_expr_types[0], "character_length") } - BuiltinScalarFunction::Chr => Ok(DataType::Utf8), + BuiltinScalarFunction::Chr => Ok(Utf8), BuiltinScalarFunction::Coalesce => { // COALESCE has multiple args and they might get coerced, get a preview of this let coerced_types = data_types(input_expr_types, &signature(fun)); coerced_types.map(|types| types[0].clone()) } - BuiltinScalarFunction::Concat => Ok(DataType::Utf8), - BuiltinScalarFunction::ConcatWithSeparator => Ok(DataType::Utf8), - BuiltinScalarFunction::DatePart => Ok(DataType::Float64), + BuiltinScalarFunction::Concat => Ok(Utf8), + BuiltinScalarFunction::ConcatWithSeparator => Ok(Utf8), + BuiltinScalarFunction::DatePart => Ok(Float64), BuiltinScalarFunction::DateTrunc | BuiltinScalarFunction::DateBin => { match input_expr_types[1] { - DataType::Timestamp(TimeUnit::Nanosecond, _) | DataType::Utf8 => { - Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - Ok(DataType::Timestamp(TimeUnit::Microsecond, None)) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - Ok(DataType::Timestamp(TimeUnit::Millisecond, None)) - } - DataType::Timestamp(TimeUnit::Second, _) => { - Ok(DataType::Timestamp(TimeUnit::Second, None)) - } + Timestamp(Nanosecond, _) | Utf8 => Ok(Timestamp(Nanosecond, None)), + Timestamp(Microsecond, _) => Ok(Timestamp(Microsecond, None)), + Timestamp(Millisecond, _) => Ok(Timestamp(Millisecond, None)), + Timestamp(Second, _) => Ok(Timestamp(Second, None)), _ => Err(DataFusionError::Internal(format!( "The {fun} function can only accept timestamp as the second arg." ))), @@ -253,9 +248,9 @@ pub fn return_type( BuiltinScalarFunction::OctetLength => { utf8_to_int_type(&input_expr_types[0], "octet_length") } - BuiltinScalarFunction::Pi => Ok(DataType::Float64), - BuiltinScalarFunction::Random => Ok(DataType::Float64), - BuiltinScalarFunction::Uuid => Ok(DataType::Utf8), + BuiltinScalarFunction::Pi => Ok(Float64), + BuiltinScalarFunction::Random => Ok(Float64), + BuiltinScalarFunction::Uuid => Ok(Utf8), BuiltinScalarFunction::RegexpReplace => { utf8_to_str_type(&input_expr_types[0], "regex_replace") } @@ -287,13 +282,11 @@ pub fn return_type( BuiltinScalarFunction::SplitPart => { utf8_to_str_type(&input_expr_types[0], "split_part") } - BuiltinScalarFunction::StartsWith => Ok(DataType::Boolean), + BuiltinScalarFunction::StartsWith => Ok(Boolean), BuiltinScalarFunction::Strpos => utf8_to_int_type(&input_expr_types[0], "strpos"), BuiltinScalarFunction::Substr => utf8_to_str_type(&input_expr_types[0], "substr"), BuiltinScalarFunction::ToHex => Ok(match input_expr_types[0] { - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { - DataType::Utf8 - } + Int8 | Int16 | Int32 | Int64 => Utf8, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( @@ -301,40 +294,23 @@ pub fn return_type( )); } }), - BuiltinScalarFunction::ToTimestamp => { - Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) - } - BuiltinScalarFunction::ToTimestampMillis => { - Ok(DataType::Timestamp(TimeUnit::Millisecond, None)) - } - BuiltinScalarFunction::ToTimestampMicros => { - Ok(DataType::Timestamp(TimeUnit::Microsecond, None)) - } - BuiltinScalarFunction::ToTimestampSeconds => { - Ok(DataType::Timestamp(TimeUnit::Second, None)) - } - BuiltinScalarFunction::FromUnixtime => { - Ok(DataType::Timestamp(TimeUnit::Second, None)) - } - BuiltinScalarFunction::Now => Ok(DataType::Timestamp( - TimeUnit::Nanosecond, - Some("+00:00".into()), - )), - BuiltinScalarFunction::CurrentDate => Ok(DataType::Date32), - BuiltinScalarFunction::CurrentTime => Ok(DataType::Time64(TimeUnit::Nanosecond)), + BuiltinScalarFunction::ToTimestamp => Ok(Timestamp(Nanosecond, None)), + BuiltinScalarFunction::ToTimestampMillis => Ok(Timestamp(Millisecond, None)), + BuiltinScalarFunction::ToTimestampMicros => Ok(Timestamp(Microsecond, None)), + BuiltinScalarFunction::ToTimestampSeconds => Ok(Timestamp(Second, None)), + BuiltinScalarFunction::FromUnixtime => Ok(Timestamp(Second, None)), + BuiltinScalarFunction::Now => Ok(Timestamp(Nanosecond, Some("+00:00".into()))), + BuiltinScalarFunction::CurrentDate => Ok(Date32), + BuiltinScalarFunction::CurrentTime => Ok(Time64(Nanosecond)), BuiltinScalarFunction::Translate => { utf8_to_str_type(&input_expr_types[0], "translate") } BuiltinScalarFunction::Trim => utf8_to_str_type(&input_expr_types[0], "trim"), BuiltinScalarFunction::Upper => utf8_to_str_type(&input_expr_types[0], "upper"), BuiltinScalarFunction::RegexpMatch => Ok(match input_expr_types[0] { - DataType::LargeUtf8 => { - DataType::List(Arc::new(Field::new("item", DataType::LargeUtf8, true))) - } - DataType::Utf8 => { - DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))) - } - DataType::Null => DataType::Null, + LargeUtf8 => List(Arc::new(Field::new("item", LargeUtf8, true))), + Utf8 => List(Arc::new(Field::new("item", Utf8, true))), + Null => Null, _ => { // this error is internal as `data_types` should have captured this. return Err(DataFusionError::Internal( @@ -345,11 +321,11 @@ pub fn return_type( BuiltinScalarFunction::Factorial | BuiltinScalarFunction::Gcd - | BuiltinScalarFunction::Lcm => Ok(DataType::Int64), + | BuiltinScalarFunction::Lcm => Ok(Int64), BuiltinScalarFunction::Power => match &input_expr_types[0] { - DataType::Int64 => Ok(DataType::Int64), - _ => Ok(DataType::Float64), + Int64 => Ok(Int64), + _ => Ok(Float64), }, BuiltinScalarFunction::Struct => { @@ -358,20 +334,20 @@ pub fn return_type( .enumerate() .map(|(pos, dt)| Field::new(format!("c{pos}"), dt.clone(), true)) .collect::>(); - Ok(DataType::Struct(Fields::from(return_fields))) + Ok(Struct(Fields::from(return_fields))) } BuiltinScalarFunction::Atan2 => match &input_expr_types[0] { - DataType::Float32 => Ok(DataType::Float32), - _ => Ok(DataType::Float64), + Float32 => Ok(Float32), + _ => Ok(Float64), }, BuiltinScalarFunction::Log => match &input_expr_types[0] { - DataType::Float32 => Ok(DataType::Float32), - _ => Ok(DataType::Float64), + Float32 => Ok(Float32), + _ => Ok(Float64), }, - BuiltinScalarFunction::ArrowTypeof => Ok(DataType::Utf8), + BuiltinScalarFunction::ArrowTypeof => Ok(Utf8), BuiltinScalarFunction::Abs | BuiltinScalarFunction::Acos @@ -399,14 +375,18 @@ pub fn return_type( | BuiltinScalarFunction::Tan | BuiltinScalarFunction::Tanh | BuiltinScalarFunction::Trunc => match input_expr_types[0] { - DataType::Float32 => Ok(DataType::Float32), - _ => Ok(DataType::Float64), + Float32 => Ok(Float32), + _ => Ok(Float64), }, } } /// Return the [`Signature`] supported by the function `fun`. pub fn signature(fun: &BuiltinScalarFunction) -> Signature { + use DataType::*; + use IntervalUnit::*; + use TimeUnit::*; + use TypeSignature::*; // note: the physical expression must accept the type returned by this function or the execution panics. // for now, the list is small, as we do not have many built-in functions. @@ -431,7 +411,7 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature { fun.volatility(), ), BuiltinScalarFunction::Concat | BuiltinScalarFunction::ConcatWithSeparator => { - Signature::variadic(vec![DataType::Utf8], fun.volatility()) + Signature::variadic(vec![Utf8], fun.volatility()) } BuiltinScalarFunction::Coalesce => Signature::variadic( conditional_expressions::SUPPORTED_COALESCE_TYPES.to_vec(), @@ -443,12 +423,7 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature { | BuiltinScalarFunction::SHA512 | BuiltinScalarFunction::MD5 => Signature::uniform( 1, - vec![ - DataType::Utf8, - DataType::LargeUtf8, - DataType::Binary, - DataType::LargeBinary, - ], + vec![Utf8, LargeUtf8, Binary, LargeBinary], fun.volatility(), ), BuiltinScalarFunction::Ascii @@ -458,213 +433,146 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature { | BuiltinScalarFunction::Lower | BuiltinScalarFunction::OctetLength | BuiltinScalarFunction::Reverse - | BuiltinScalarFunction::Upper => Signature::uniform( - 1, - vec![DataType::Utf8, DataType::LargeUtf8], - fun.volatility(), - ), + | BuiltinScalarFunction::Upper => { + Signature::uniform(1, vec![Utf8, LargeUtf8], fun.volatility()) + } BuiltinScalarFunction::Btrim | BuiltinScalarFunction::Ltrim | BuiltinScalarFunction::Rtrim | BuiltinScalarFunction::Trim => Signature::one_of( - vec![ - TypeSignature::Exact(vec![DataType::Utf8]), - TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]), - ], + vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], fun.volatility(), ), BuiltinScalarFunction::Chr | BuiltinScalarFunction::ToHex => { - Signature::uniform(1, vec![DataType::Int64], fun.volatility()) + Signature::uniform(1, vec![Int64], fun.volatility()) } BuiltinScalarFunction::Lpad | BuiltinScalarFunction::Rpad => Signature::one_of( vec![ - TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]), - TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Int64]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Int64, - DataType::Utf8, - ]), - TypeSignature::Exact(vec![ - DataType::LargeUtf8, - DataType::Int64, - DataType::Utf8, - ]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Int64, - DataType::LargeUtf8, - ]), - TypeSignature::Exact(vec![ - DataType::LargeUtf8, - DataType::Int64, - DataType::LargeUtf8, - ]), + Exact(vec![Utf8, Int64]), + Exact(vec![LargeUtf8, Int64]), + Exact(vec![Utf8, Int64, Utf8]), + Exact(vec![LargeUtf8, Int64, Utf8]), + Exact(vec![Utf8, Int64, LargeUtf8]), + Exact(vec![LargeUtf8, Int64, LargeUtf8]), ], fun.volatility(), ), BuiltinScalarFunction::Left | BuiltinScalarFunction::Repeat | BuiltinScalarFunction::Right => Signature::one_of( - vec![ - TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]), - TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Int64]), - ], + vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], fun.volatility(), ), BuiltinScalarFunction::ToTimestamp => Signature::uniform( 1, vec![ - DataType::Int64, - DataType::Timestamp(TimeUnit::Nanosecond, None), - DataType::Timestamp(TimeUnit::Microsecond, None), - DataType::Timestamp(TimeUnit::Millisecond, None), - DataType::Timestamp(TimeUnit::Second, None), - DataType::Utf8, + Int64, + Timestamp(Nanosecond, None), + Timestamp(Microsecond, None), + Timestamp(Millisecond, None), + Timestamp(Second, None), + Utf8, ], fun.volatility(), ), BuiltinScalarFunction::ToTimestampMillis => Signature::uniform( 1, vec![ - DataType::Int64, - DataType::Timestamp(TimeUnit::Nanosecond, None), - DataType::Timestamp(TimeUnit::Microsecond, None), - DataType::Timestamp(TimeUnit::Millisecond, None), - DataType::Timestamp(TimeUnit::Second, None), - DataType::Utf8, + Int64, + Timestamp(Nanosecond, None), + Timestamp(Microsecond, None), + Timestamp(Millisecond, None), + Timestamp(Second, None), + Utf8, ], fun.volatility(), ), BuiltinScalarFunction::ToTimestampMicros => Signature::uniform( 1, vec![ - DataType::Int64, - DataType::Timestamp(TimeUnit::Nanosecond, None), - DataType::Timestamp(TimeUnit::Microsecond, None), - DataType::Timestamp(TimeUnit::Millisecond, None), - DataType::Timestamp(TimeUnit::Second, None), - DataType::Utf8, + Int64, + Timestamp(Nanosecond, None), + Timestamp(Microsecond, None), + Timestamp(Millisecond, None), + Timestamp(Second, None), + Utf8, ], fun.volatility(), ), BuiltinScalarFunction::ToTimestampSeconds => Signature::uniform( 1, vec![ - DataType::Int64, - DataType::Timestamp(TimeUnit::Nanosecond, None), - DataType::Timestamp(TimeUnit::Microsecond, None), - DataType::Timestamp(TimeUnit::Millisecond, None), - DataType::Timestamp(TimeUnit::Second, None), - DataType::Utf8, + Int64, + Timestamp(Nanosecond, None), + Timestamp(Microsecond, None), + Timestamp(Millisecond, None), + Timestamp(Second, None), + Utf8, ], fun.volatility(), ), BuiltinScalarFunction::FromUnixtime => { - Signature::uniform(1, vec![DataType::Int64], fun.volatility()) + Signature::uniform(1, vec![Int64], fun.volatility()) } BuiltinScalarFunction::Digest => Signature::one_of( vec![ - TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]), - TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), - TypeSignature::Exact(vec![DataType::Binary, DataType::Utf8]), - TypeSignature::Exact(vec![DataType::LargeBinary, DataType::Utf8]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::DateTrunc => Signature::exact( - vec![ - DataType::Utf8, - DataType::Timestamp(TimeUnit::Nanosecond, None), + Exact(vec![Utf8, Utf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![Binary, Utf8]), + Exact(vec![LargeBinary, Utf8]), ], fun.volatility(), ), + BuiltinScalarFunction::DateTrunc => { + Signature::exact(vec![Utf8, Timestamp(Nanosecond, None)], fun.volatility()) + } BuiltinScalarFunction::DateBin => { let base_sig = |array_type: TimeUnit| { vec![ - TypeSignature::Exact(vec![ - DataType::Interval(IntervalUnit::MonthDayNano), - DataType::Timestamp(array_type.clone(), None), - DataType::Timestamp(TimeUnit::Nanosecond, None), - ]), - TypeSignature::Exact(vec![ - DataType::Interval(IntervalUnit::DayTime), - DataType::Timestamp(array_type.clone(), None), - DataType::Timestamp(TimeUnit::Nanosecond, None), + Exact(vec![ + Interval(MonthDayNano), + Timestamp(array_type.clone(), None), + Timestamp(Nanosecond, None), ]), - TypeSignature::Exact(vec![ - DataType::Interval(IntervalUnit::MonthDayNano), - DataType::Timestamp(array_type.clone(), None), + Exact(vec![ + Interval(DayTime), + Timestamp(array_type.clone(), None), + Timestamp(Nanosecond, None), ]), - TypeSignature::Exact(vec![ - DataType::Interval(IntervalUnit::DayTime), - DataType::Timestamp(array_type, None), + Exact(vec![ + Interval(MonthDayNano), + Timestamp(array_type.clone(), None), ]), + Exact(vec![Interval(DayTime), Timestamp(array_type, None)]), ] }; - let full_sig = [ - TimeUnit::Nanosecond, - TimeUnit::Microsecond, - TimeUnit::Millisecond, - TimeUnit::Second, - ] - .into_iter() - .map(base_sig) - .collect::>() - .concat(); + let full_sig = [Nanosecond, Microsecond, Millisecond, Second] + .into_iter() + .map(base_sig) + .collect::>() + .concat(); Signature::one_of(full_sig, fun.volatility()) } BuiltinScalarFunction::DatePart => Signature::one_of( vec![ - TypeSignature::Exact(vec![DataType::Utf8, DataType::Date32]), - TypeSignature::Exact(vec![DataType::Utf8, DataType::Date64]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Timestamp(TimeUnit::Second, None), - ]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Timestamp(TimeUnit::Microsecond, None), - ]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Timestamp(TimeUnit::Millisecond, None), - ]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Timestamp(TimeUnit::Nanosecond, None), - ]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())), - ]), + Exact(vec![Utf8, Date32]), + Exact(vec![Utf8, Date64]), + Exact(vec![Utf8, Timestamp(Second, None)]), + Exact(vec![Utf8, Timestamp(Microsecond, None)]), + Exact(vec![Utf8, Timestamp(Millisecond, None)]), + Exact(vec![Utf8, Timestamp(Nanosecond, None)]), + Exact(vec![Utf8, Timestamp(Nanosecond, Some("+00:00".into()))]), ], fun.volatility(), ), BuiltinScalarFunction::SplitPart => Signature::one_of( vec![ - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Utf8, - DataType::Int64, - ]), - TypeSignature::Exact(vec![ - DataType::LargeUtf8, - DataType::Utf8, - DataType::Int64, - ]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::LargeUtf8, - DataType::Int64, - ]), - TypeSignature::Exact(vec![ - DataType::LargeUtf8, - DataType::LargeUtf8, - DataType::Int64, - ]), + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, Utf8, Int64]), + Exact(vec![Utf8, LargeUtf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), ], fun.volatility(), ), @@ -672,10 +580,10 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature { BuiltinScalarFunction::Strpos | BuiltinScalarFunction::StartsWith => { Signature::one_of( vec![ - TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]), - TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]), - TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), - TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]), + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8, LargeUtf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), ], fun.volatility(), ) @@ -683,45 +591,21 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature { BuiltinScalarFunction::Substr => Signature::one_of( vec![ - TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]), - TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Int64]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Int64, - DataType::Int64, - ]), - TypeSignature::Exact(vec![ - DataType::LargeUtf8, - DataType::Int64, - DataType::Int64, - ]), + Exact(vec![Utf8, Int64]), + Exact(vec![LargeUtf8, Int64]), + Exact(vec![Utf8, Int64, Int64]), + Exact(vec![LargeUtf8, Int64, Int64]), ], fun.volatility(), ), BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => { - Signature::one_of( - vec![TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Utf8, - DataType::Utf8, - ])], - fun.volatility(), - ) + Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], fun.volatility()) } BuiltinScalarFunction::RegexpReplace => Signature::one_of( vec![ - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Utf8, - DataType::Utf8, - ]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Utf8, - DataType::Utf8, - DataType::Utf8, - ]), + Exact(vec![Utf8, Utf8, Utf8]), + Exact(vec![Utf8, Utf8, Utf8, Utf8]), ], fun.volatility(), ), @@ -731,18 +615,10 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature { } BuiltinScalarFunction::RegexpMatch => Signature::one_of( vec![ - TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]), - TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), - TypeSignature::Exact(vec![ - DataType::Utf8, - DataType::Utf8, - DataType::Utf8, - ]), - TypeSignature::Exact(vec![ - DataType::LargeUtf8, - DataType::Utf8, - DataType::Utf8, - ]), + Exact(vec![Utf8, Utf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![Utf8, Utf8, Utf8]), + Exact(vec![LargeUtf8, Utf8, Utf8]), ], fun.volatility(), ), @@ -750,42 +626,36 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature { BuiltinScalarFunction::Random => Signature::exact(vec![], fun.volatility()), BuiltinScalarFunction::Uuid => Signature::exact(vec![], fun.volatility()), BuiltinScalarFunction::Power => Signature::one_of( - vec![ - TypeSignature::Exact(vec![DataType::Int64, DataType::Int64]), - TypeSignature::Exact(vec![DataType::Float64, DataType::Float64]), - ], + vec![Exact(vec![Int64, Int64]), Exact(vec![Float64, Float64])], fun.volatility(), ), BuiltinScalarFunction::Round => Signature::one_of( vec![ - TypeSignature::Exact(vec![DataType::Float64, DataType::Int64]), - TypeSignature::Exact(vec![DataType::Float32, DataType::Int64]), - TypeSignature::Exact(vec![DataType::Float64]), - TypeSignature::Exact(vec![DataType::Float32]), + Exact(vec![Float64, Int64]), + Exact(vec![Float32, Int64]), + Exact(vec![Float64]), + Exact(vec![Float32]), ], fun.volatility(), ), BuiltinScalarFunction::Atan2 => Signature::one_of( - vec![ - TypeSignature::Exact(vec![DataType::Float32, DataType::Float32]), - TypeSignature::Exact(vec![DataType::Float64, DataType::Float64]), - ], + vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], fun.volatility(), ), BuiltinScalarFunction::Log => Signature::one_of( vec![ - TypeSignature::Exact(vec![DataType::Float32]), - TypeSignature::Exact(vec![DataType::Float64]), - TypeSignature::Exact(vec![DataType::Float32, DataType::Float32]), - TypeSignature::Exact(vec![DataType::Float64, DataType::Float64]), + Exact(vec![Float32]), + Exact(vec![Float64]), + Exact(vec![Float32, Float32]), + Exact(vec![Float64, Float64]), ], fun.volatility(), ), BuiltinScalarFunction::Factorial => { - Signature::uniform(1, vec![DataType::Int64], fun.volatility()) + Signature::uniform(1, vec![Int64], fun.volatility()) } BuiltinScalarFunction::Gcd | BuiltinScalarFunction::Lcm => { - Signature::uniform(2, vec![DataType::Int64], fun.volatility()) + Signature::uniform(2, vec![Int64], fun.volatility()) } BuiltinScalarFunction::ArrowTypeof => Signature::any(1, fun.volatility()), BuiltinScalarFunction::Abs @@ -818,11 +688,7 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature { // return the best approximation for it (in f64). // We accept f32 because in this case it is clear that the best approximation // will be as good as the number of digits in the number - Signature::uniform( - 1, - vec![DataType::Float64, DataType::Float32], - fun.volatility(), - ) + Signature::uniform(1, vec![Float64, Float32], fun.volatility()) } BuiltinScalarFunction::Now | BuiltinScalarFunction::CurrentDate diff --git a/datafusion/expr/src/function_err.rs b/datafusion/expr/src/function_err.rs index e97e0f92cd80..1635ac3b0c8c 100644 --- a/datafusion/expr/src/function_err.rs +++ b/datafusion/expr/src/function_err.rs @@ -84,7 +84,7 @@ pub fn generate_signature_error_msg( .type_signature .to_string_repr() .iter() - .map(|args_str| format!("\t{}({})", fun, args_str)) + .map(|args_str| format!("\t{fun}({args_str})")) .collect::>() .join("\n"); diff --git a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs index 57513fa4fff4..90c96b4b8b8c 100644 --- a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs +++ b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs @@ -119,7 +119,7 @@ pub struct RewriteDisjunctivePredicate; impl RewriteDisjunctivePredicate { pub fn new() -> Self { - Self::default() + Self } } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index e5b66d4a3987..994159a8a77d 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -1060,8 +1060,7 @@ fn to_result_type_array( Ok(cast(&array, result_type)?) } else { Err(DataFusionError::Internal(format!( - "Incompatible Dictionary value type {:?} with result type {:?} of Binary operator {:?}", - value_type, result_type, op + "Incompatible Dictionary value type {value_type:?} with result type {result_type:?} of Binary operator {op:?}" ))) } } diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 3774ce14305d..646d02384de5 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -353,7 +353,7 @@ impl AsLogicalPlan for LogicalPlanNode { .with_has_header(*has_header) .with_delimiter(str_to_byte(delimiter)?), ), - FileFormatType::Avro(..) => Arc::new(AvroFormat::default()), + FileFormatType::Avro(..) => Arc::new(AvroFormat), }; let table_paths = &scan diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index 38dacf35be13..6bd6ffded223 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -664,8 +664,7 @@ impl<'a> DFParser<'a> { break; } else { return Err(ParserError::ParserError(format!( - "Unexpected token {}", - token + "Unexpected token {token}" ))); } } From e6265c1e3f3e4a3f2a80044f3659fdc0ab688246 Mon Sep 17 00:00:00 2001 From: Nuttiiya Seekhao <37189615+nseekhao@users.noreply.github.com> Date: Sun, 11 Jun 2023 08:36:29 -0400 Subject: [PATCH 5/7] Substrait: Fix incorrect join key fields (indices) when same table is being used more than once (#6135) * Fix incorrect join key fields (indices) when same table is being used more than once * Addressed comments Update datafusion/substrait/src/logical_plan/producer.rs Co-authored-by: Ruihang Xia Update datafusion/substrait/src/logical_plan/producer.rs Co-authored-by: Ruihang Xia * Fixed bugs after rebase --------- Co-authored-by: Ruihang Xia --- .../substrait/src/logical_plan/consumer.rs | 2 +- .../substrait/src/logical_plan/producer.rs | 191 ++++++++++++------ .../substrait/tests/roundtrip_logical_plan.rs | 24 +++ 3 files changed, 155 insertions(+), 62 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index f914b62a1452..f15ffdf42374 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -365,7 +365,7 @@ pub async fn from_substrait_rel( )), }, _ => Err(DataFusionError::Internal( - "invalid join condition expresssion".to_string(), + "invalid join condition expression".to_string(), )), } } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 785bfa4ea6a7..228341548813 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, @@ -32,7 +32,7 @@ use datafusion::logical_expr::expr::{ BinaryExpr, Case, Cast, ScalarFunction as DFScalarFunction, Sort, WindowFunction, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; -use datafusion::prelude::{binary_expr, Expr}; +use datafusion::prelude::Expr; use prost_types::Any as ProtoAny; use substrait::{ proto::{ @@ -156,7 +156,7 @@ pub fn to_substrait_rel( let expressions = p .expr .iter() - .map(|e| to_substrait_rex(e, p.input.schema(), extension_info)) + .map(|e| to_substrait_rex(e, p.input.schema(), 0, extension_info)) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Project(Box::new(ProjectRel { @@ -172,6 +172,7 @@ pub fn to_substrait_rel( let filter_expr = to_substrait_rex( &filter.predicate, filter.input.schema(), + 0, extension_info, )?; Ok(Box::new(Rel { @@ -218,7 +219,7 @@ pub fn to_substrait_rel( let grouping = agg .group_expr .iter() - .map(|e| to_substrait_rex(e, agg.input.schema(), extension_info)) + .map(|e| to_substrait_rex(e, agg.input.schema(), 0, extension_info)) .collect::>>()?; let measures = agg .aggr_expr @@ -281,45 +282,24 @@ pub fn to_substrait_rel( } else { Operator::Eq }; - let join_expression = join - .on - .iter() - .map(|(l, r)| binary_expr(l.clone(), eq_op, r.clone())) - .reduce(|acc: Expr, expr: Expr| acc.and(expr)); - // join schema from left and right to maintain all nececesary columns from inputs - // note that we cannot simple use join.schema here since we discard some input columns - // when performing semi and anti joins - let join_schema = match join.left.schema().join(join.right.schema()) { - Ok(schema) => Ok(schema), - Err(DataFusionError::SchemaError( - datafusion::common::SchemaError::DuplicateQualifiedField { - qualifier: _, - name: _, - }, - )) => Ok(join.schema.as_ref().clone()), - Err(e) => Err(e), - }; - if let Some(e) = join_expression { - Ok(Box::new(Rel { - rel_type: Some(RelType::Join(Box::new(JoinRel { - common: None, - left: Some(left), - right: Some(right), - r#type: join_type as i32, - expression: Some(Box::new(to_substrait_rex( - &e, - &Arc::new(join_schema?), - extension_info, - )?)), - post_join_filter: None, - advanced_extension: None, - }))), - })) - } else { - Err(DataFusionError::NotImplemented( - "Empty join condition".to_string(), - )) - } + + Ok(Box::new(Rel { + rel_type: Some(RelType::Join(Box::new(JoinRel { + common: None, + left: Some(left), + right: Some(right), + r#type: join_type as i32, + expression: Some(Box::new(to_substrait_join_expr( + &join.on, + eq_op, + join.left.schema(), + join.right.schema(), + extension_info, + )?)), + post_join_filter: None, + advanced_extension: None, + }))), + })) } LogicalPlan::SubqueryAlias(alias) => { // Do nothing if encounters SubqueryAlias @@ -353,6 +333,7 @@ pub fn to_substrait_rel( window_exprs.push(to_substrait_rex( expr, window.input.schema(), + 0, extension_info, )?); } @@ -403,6 +384,40 @@ pub fn to_substrait_rel( } } +fn to_substrait_join_expr( + join_conditions: &Vec<(Expr, Expr)>, + eq_op: Operator, + left_schema: &DFSchemaRef, + right_schema: &DFSchemaRef, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> Result { + // Only support AND conjunction for each binary expression in join conditions + let mut exprs: Vec = vec![]; + for (left, right) in join_conditions { + // Parse left + let l = to_substrait_rex(left, left_schema, 0, extension_info)?; + // Parse right + let r = to_substrait_rex( + right, + right_schema, + left_schema.fields().len(), // offset to return the correct index + extension_info, + )?; + // AND with existing expression + exprs.push(make_binary_op_scalar_func(&l, &r, eq_op, extension_info)); + } + let join_expr: Expression = exprs + .into_iter() + .reduce(|acc: Expression, e: Expression| { + make_binary_op_scalar_func(&acc, &e, Operator::And, extension_info) + }) + .unwrap(); + Ok(join_expr) +} + fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType { match join_type { JoinType::Inner => join_rel::JoinType::Inner, @@ -459,7 +474,7 @@ pub fn to_substrait_agg_measure( Expr::AggregateFunction(expr::AggregateFunction { fun, args, distinct, filter, order_by: _order_by }) => { let mut arguments: Vec = vec![]; for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, extension_info)?)) }); + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); } let function_name = fun.to_string().to_lowercase(); let function_anchor = _register_function(function_name, extension_info); @@ -478,7 +493,7 @@ pub fn to_substrait_agg_measure( options: vec![], }), filter: match filter { - Some(f) => Some(to_substrait_rex(f, schema, extension_info)?), + Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), None => None } }) @@ -566,10 +581,33 @@ pub fn make_binary_op_scalar_func( } /// Convert DataFusion Expr to Substrait Rex +/// +/// # Arguments +/// +/// * `expr` - DataFusion expression to be parse into a Substrait expression +/// * `schema` - DataFusion input schema for looking up field qualifiers +/// * `col_ref_offset` - Offset for caculating Substrait field reference indices. +/// This should only be set by caller with more than one input relations i.e. Join. +/// Substrait expects one set of indices when joining two relations. +/// Let's say `left` and `right` have `m` and `n` columns, respectively. The `right` +/// relation will have column indices from `0` to `n-1`, however, Substrait will expect +/// the `right` indices to be offset by the `left`. This means Substrait will expect to +/// evaluate the join condition expression on indices [0 .. n-1, n .. n+m-1]. For example: +/// ```SELECT * +/// FROM t1 +/// JOIN t2 +/// ON t1.c1 = t2.c0;``` +/// where t1 consists of columns [c0, c1, c2], and t2 = columns [c0, c1] +/// the join condition should become +/// `col_ref(1) = col_ref(3 + 0)` +/// , where `3` is the number of `left` columns (`col_ref_offset`) and `0` is the index +/// of the join key column from `right` +/// * `extension_info` - Substrait extension info. Contains registered function information #[allow(deprecated)] pub fn to_substrait_rex( expr: &Expr, schema: &DFSchemaRef, + col_ref_offset: usize, extension_info: &mut ( Vec, HashMap, @@ -583,6 +621,7 @@ pub fn to_substrait_rex( arg_type: Some(ArgType::Value(to_substrait_rex( arg, schema, + col_ref_offset, extension_info, )?)), }); @@ -607,9 +646,12 @@ pub fn to_substrait_rex( }) => { if *negated { // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) - let substrait_expr = to_substrait_rex(expr, schema, extension_info)?; - let substrait_low = to_substrait_rex(low, schema, extension_info)?; - let substrait_high = to_substrait_rex(high, schema, extension_info)?; + let substrait_expr = + to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + let substrait_low = + to_substrait_rex(low, schema, col_ref_offset, extension_info)?; + let substrait_high = + to_substrait_rex(high, schema, col_ref_offset, extension_info)?; let l_expr = make_binary_op_scalar_func( &substrait_expr, @@ -632,9 +674,12 @@ pub fn to_substrait_rex( )) } else { // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) - let substrait_expr = to_substrait_rex(expr, schema, extension_info)?; - let substrait_low = to_substrait_rex(low, schema, extension_info)?; - let substrait_high = to_substrait_rex(high, schema, extension_info)?; + let substrait_expr = + to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + let substrait_low = + to_substrait_rex(low, schema, col_ref_offset, extension_info)?; + let substrait_high = + to_substrait_rex(high, schema, col_ref_offset, extension_info)?; let l_expr = make_binary_op_scalar_func( &substrait_low, @@ -659,11 +704,11 @@ pub fn to_substrait_rex( } Expr::Column(col) => { let index = schema.index_of_column(col)?; - substrait_field_ref(index) + substrait_field_ref(index + col_ref_offset) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = to_substrait_rex(left, schema, extension_info)?; - let r = to_substrait_rex(right, schema, extension_info)?; + let l = to_substrait_rex(left, schema, col_ref_offset, extension_info)?; + let r = to_substrait_rex(right, schema, col_ref_offset, extension_info)?; Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info)) } @@ -677,21 +722,41 @@ pub fn to_substrait_rex( if let Some(e) = expr { // Base expression exists ifs.push(IfClause { - r#if: Some(to_substrait_rex(e, schema, extension_info)?), + r#if: Some(to_substrait_rex( + e, + schema, + col_ref_offset, + extension_info, + )?), then: None, }); } // Parse `when`s for (r#if, then) in when_then_expr { ifs.push(IfClause { - r#if: Some(to_substrait_rex(r#if, schema, extension_info)?), - then: Some(to_substrait_rex(then, schema, extension_info)?), + r#if: Some(to_substrait_rex( + r#if, + schema, + col_ref_offset, + extension_info, + )?), + then: Some(to_substrait_rex( + then, + schema, + col_ref_offset, + extension_info, + )?), }); } // Parse outer `else` let r#else: Option> = match else_expr { - Some(e) => Some(Box::new(to_substrait_rex(e, schema, extension_info)?)), + Some(e) => Some(Box::new(to_substrait_rex( + e, + schema, + col_ref_offset, + extension_info, + )?)), None => None, }; @@ -707,6 +772,7 @@ pub fn to_substrait_rex( input: Some(Box::new(to_substrait_rex( expr, schema, + col_ref_offset, extension_info, )?)), failure_behavior: 0, // FAILURE_BEHAVIOR_UNSPECIFIED @@ -715,7 +781,9 @@ pub fn to_substrait_rex( }) } Expr::Literal(value) => to_substrait_literal(value), - Expr::Alias(expr, _alias) => to_substrait_rex(expr, schema, extension_info), + Expr::Alias(expr, _alias) => { + to_substrait_rex(expr, schema, col_ref_offset, extension_info) + } Expr::WindowFunction(WindowFunction { fun, args, @@ -733,6 +801,7 @@ pub fn to_substrait_rex( arg_type: Some(ArgType::Value(to_substrait_rex( arg, schema, + col_ref_offset, extension_info, )?)), }); @@ -740,7 +809,7 @@ pub fn to_substrait_rex( // partition by expressions let partition_by = partition_by .iter() - .map(|e| to_substrait_rex(e, schema, extension_info)) + .map(|e| to_substrait_rex(e, schema, col_ref_offset, extension_info)) .collect::>>()?; // order by expressions let order_by = order_by @@ -1325,7 +1394,7 @@ fn substrait_sort_field( asc, nulls_first, }) => { - let e = to_substrait_rex(expr, schema, extension_info)?; + let e = to_substrait_rex(expr, schema, 0, extension_info)?; let d = match (asc, nulls_first) { (true, true) => SortDirection::AscNullsFirst, (true, false) => SortDirection::AscNullsLast, diff --git a/datafusion/substrait/tests/roundtrip_logical_plan.rs b/datafusion/substrait/tests/roundtrip_logical_plan.rs index 8cdf89b29473..e209ebedc0f3 100644 --- a/datafusion/substrait/tests/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/roundtrip_logical_plan.rs @@ -412,6 +412,30 @@ mod tests { roundtrip("SELECT a,b,c,d,e FROM datafusion.public.data;").await } + #[tokio::test] + async fn roundtrip_inner_join_table_reuse_zero_index() -> Result<()> { + assert_expected_plan( + "SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.a = d2.a", + "Projection: data.b, data.c\ + \n Inner Join: data.a = data.a\ + \n TableScan: data projection=[a, b]\ + \n TableScan: data projection=[a, c]", + ) + .await + } + + #[tokio::test] + async fn roundtrip_inner_join_table_reuse_non_zero_index() -> Result<()> { + assert_expected_plan( + "SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.b = d2.b", + "Projection: data.b, data.c\ + \n Inner Join: data.b = data.b\ + \n TableScan: data projection=[b]\ + \n TableScan: data projection=[b, c]", + ) + .await + } + /// Construct a plan that contains several literals of types that are currently supported. /// This case ignores: /// - Date64, for this literal is not supported From aebe71eace36f0c6683da031262971ad017b48c5 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 11 Jun 2023 09:31:12 -0400 Subject: [PATCH 6/7] Minor: Add debug logging for schema mismatch errors (#6626) --- datafusion/core/src/datasource/memory.rs | 36 ++++++++++--------- datafusion/core/src/datasource/streaming.rs | 16 ++++++--- .../core/src/physical_plan/streaming.rs | 16 ++++++--- 3 files changed, 44 insertions(+), 24 deletions(-) diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index f66b44e9d1f9..3f6316d28ab5 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -18,6 +18,7 @@ //! [`MemTable`] for querying `Vec` by DataFusion. use futures::StreamExt; +use log::debug; use std::any::Any; use std::fmt::{self, Debug, Display}; use std::sync::Arc; @@ -55,23 +56,26 @@ pub struct MemTable { impl MemTable { /// Create a new in-memory table from the provided schema and record batches pub fn try_new(schema: SchemaRef, partitions: Vec>) -> Result { - if partitions - .iter() - .flatten() - .all(|batches| schema.contains(&batches.schema())) - { - Ok(Self { - schema, - batches: partitions - .into_iter() - .map(|e| Arc::new(RwLock::new(e))) - .collect::>(), - }) - } else { - Err(DataFusionError::Plan( - "Mismatch between schema and batches".to_string(), - )) + for batches in partitions.iter().flatten() { + let batches_schema = batches.schema(); + if !schema.contains(&batches_schema) { + debug!( + "mem table schema does not contain batches schema. \ + Target_schema: {schema:?}. Batches Schema: {batches_schema:?}" + ); + return Err(DataFusionError::Plan( + "Mismatch between schema and batches".to_string(), + )); + } } + + Ok(Self { + schema, + batches: partitions + .into_iter() + .map(|e| Arc::new(RwLock::new(e))) + .collect::>(), + }) } /// Create a mem table by reading from another data source diff --git a/datafusion/core/src/datasource/streaming.rs b/datafusion/core/src/datasource/streaming.rs index 4a234fbe138b..a5fc6f19290c 100644 --- a/datafusion/core/src/datasource/streaming.rs +++ b/datafusion/core/src/datasource/streaming.rs @@ -25,6 +25,7 @@ use async_trait::async_trait; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{Expr, TableType}; +use log::debug; use crate::datasource::TableProvider; use crate::execution::context::{SessionState, TaskContext}; @@ -53,10 +54,17 @@ impl StreamingTable { schema: SchemaRef, partitions: Vec>, ) -> Result { - if !partitions.iter().all(|x| schema.contains(x.schema())) { - return Err(DataFusionError::Plan( - "Mismatch between schema and batches".to_string(), - )); + for x in partitions.iter() { + let partition_schema = x.schema(); + if !schema.contains(partition_schema) { + debug!( + "target schema does not contain partition schema. \ + Target_schema: {schema:?}. Partiton Schema: {partition_schema:?}" + ); + return Err(DataFusionError::Plan( + "Mismatch between schema and batches".to_string(), + )); + } } Ok(Self { diff --git a/datafusion/core/src/physical_plan/streaming.rs b/datafusion/core/src/physical_plan/streaming.rs index 0555c1ce2899..797e19c46737 100644 --- a/datafusion/core/src/physical_plan/streaming.rs +++ b/datafusion/core/src/physical_plan/streaming.rs @@ -26,6 +26,7 @@ use futures::stream::StreamExt; use datafusion_common::{DataFusionError, Result, Statistics}; use datafusion_physical_expr::PhysicalSortExpr; +use log::debug; use crate::datasource::streaming::PartitionStream; use crate::physical_plan::stream::RecordBatchStreamAdapter; @@ -48,10 +49,17 @@ impl StreamingTableExec { projection: Option<&Vec>, infinite: bool, ) -> Result { - if !partitions.iter().all(|x| schema.contains(x.schema())) { - return Err(DataFusionError::Plan( - "Mismatch between schema and batches".to_string(), - )); + for x in partitions.iter() { + let partition_schema = x.schema(); + if !schema.contains(partition_schema) { + debug!( + "target schema does not contain partition schema. \ + Target_schema: {schema:?}. Partiton Schema: {partition_schema:?}" + ); + return Err(DataFusionError::Plan( + "Mismatch between schema and batches".to_string(), + )); + } } let projected_schema = match projection { From d024f379448d8ab9ef14e3581a2d844aceb50221 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 11 Jun 2023 10:19:28 -0400 Subject: [PATCH 7/7] Move functionality into `BuildInScalarFunction` (#6612) --- datafusion/expr/src/built_in_function.rs | 712 +++++++++++++++++- datafusion/expr/src/expr_schema.rs | 6 +- datafusion/expr/src/function.rs | 685 ++--------------- datafusion/expr/src/function_err.rs | 125 --- datafusion/expr/src/lib.rs | 1 - datafusion/expr/src/signature.rs | 42 ++ .../optimizer/src/analyzer/type_coercion.rs | 6 +- datafusion/physical-expr/src/functions.rs | 6 +- datafusion/sql/src/expr/function.rs | 2 +- 9 files changed, 806 insertions(+), 779 deletions(-) delete mode 100644 datafusion/expr/src/function_err.rs diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 2272997fae06..0f229059d05c 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -17,11 +17,17 @@ //! Built-in functions module contains all the built-in functions definitions. -use crate::Volatility; +use crate::nullif::SUPPORTED_NULLIF_TYPES; +use crate::type_coercion::functions::data_types; +use crate::{ + conditional_expressions, struct_expressions, Signature, TypeSignature, Volatility, +}; +use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit}; use datafusion_common::{DataFusionError, Result}; use std::collections::HashMap; use std::fmt; use std::str::FromStr; +use std::sync::Arc; use strum::IntoEnumIterator; use strum_macros::EnumIter; @@ -383,6 +389,672 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Uuid => Volatility::Volatile, } } + + /// Creates a detailed error message for a function with wrong signature. + /// + /// For example, a query like `select round(3.14, 1.1);` would yield: + /// ```text + /// Error during planning: No function matches 'round(Float64, Float64)'. You might need to add explicit type casts. + /// Candidate functions: + /// round(Float64, Int64) + /// round(Float32, Int64) + /// round(Float64) + /// round(Float32) + /// ``` + fn generate_signature_error_msg(&self, input_expr_types: &[DataType]) -> String { + let candidate_signatures = self + .signature() + .type_signature + .to_string_repr() + .iter() + .map(|args_str| format!("\t{self}({args_str})")) + .collect::>() + .join("\n"); + + format!( + "No function matches the given name and argument types '{}({})'. You might need to add explicit type casts.\n\tCandidate functions:\n{}", + self, TypeSignature::join_types(input_expr_types, ", "), candidate_signatures + ) + } + + /// Returns the output [`DataType` of this function + pub fn return_type(self, input_expr_types: &[DataType]) -> Result { + use DataType::*; + use TimeUnit::*; + + // Note that this function *must* return the same type that the respective physical expression returns + // or the execution panics. + + if input_expr_types.is_empty() && !self.supports_zero_argument() { + return Err(DataFusionError::Plan( + self.generate_signature_error_msg(input_expr_types), + )); + } + + // verify that this is a valid set of data types for this function + data_types(input_expr_types, &self.signature()).map_err(|_| { + DataFusionError::Plan(self.generate_signature_error_msg(input_expr_types)) + })?; + + // the return type of the built in function. + // Some built-in functions' return type depends on the incoming type. + match self { + BuiltinScalarFunction::ArrayAppend => match &input_expr_types[0] { + List(field) => Ok(List(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + )))), + _ => Err(DataFusionError::Internal(format!( + "The {self} function can only accept list as the first argument" + ))), + }, + BuiltinScalarFunction::ArrayConcat => match &input_expr_types[0] { + List(field) => Ok(List(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + )))), + _ => Err(DataFusionError::Internal(format!( + "The {self} function can only accept fixed size list as the args." + ))), + }, + BuiltinScalarFunction::ArrayDims => Ok(UInt8), + BuiltinScalarFunction::ArrayFill => Ok(List(Arc::new(Field::new( + "item", + input_expr_types[0].clone(), + true, + )))), + BuiltinScalarFunction::ArrayLength => Ok(UInt8), + BuiltinScalarFunction::ArrayNdims => Ok(UInt8), + BuiltinScalarFunction::ArrayPosition => Ok(UInt8), + BuiltinScalarFunction::ArrayPositions => Ok(UInt8), + BuiltinScalarFunction::ArrayPrepend => match &input_expr_types[1] { + List(field) => Ok(List(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + )))), + _ => Err(DataFusionError::Internal(format!( + "The {self} function can only accept list as the first argument" + ))), + }, + BuiltinScalarFunction::ArrayRemove => match &input_expr_types[0] { + List(field) => Ok(List(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + )))), + _ => Err(DataFusionError::Internal(format!( + "The {self} function can only accept list as the first argument" + ))), + }, + BuiltinScalarFunction::ArrayReplace => match &input_expr_types[0] { + List(field) => Ok(List(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + )))), + _ => Err(DataFusionError::Internal(format!( + "The {self} function can only accept list as the first argument" + ))), + }, + BuiltinScalarFunction::ArrayToString => match &input_expr_types[0] { + List(field) => Ok(List(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + )))), + _ => Err(DataFusionError::Internal(format!( + "The {self} function can only accept list as the first argument" + ))), + }, + BuiltinScalarFunction::Cardinality => Ok(UInt64), + BuiltinScalarFunction::MakeArray => Ok(List(Arc::new(Field::new( + "item", + input_expr_types[0].clone(), + true, + )))), + BuiltinScalarFunction::TrimArray => match &input_expr_types[0] { + List(field) => Ok(List(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + )))), + _ => Err(DataFusionError::Internal(format!( + "The {self} function can only accept list as the first argument" + ))), + }, + BuiltinScalarFunction::Ascii => Ok(Int32), + BuiltinScalarFunction::BitLength => { + utf8_to_int_type(&input_expr_types[0], "bit_length") + } + BuiltinScalarFunction::Btrim => { + utf8_to_str_type(&input_expr_types[0], "btrim") + } + BuiltinScalarFunction::CharacterLength => { + utf8_to_int_type(&input_expr_types[0], "character_length") + } + BuiltinScalarFunction::Chr => Ok(Utf8), + BuiltinScalarFunction::Coalesce => { + // COALESCE has multiple args and they might get coerced, get a preview of this + let coerced_types = data_types(input_expr_types, &self.signature()); + coerced_types.map(|types| types[0].clone()) + } + BuiltinScalarFunction::Concat => Ok(Utf8), + BuiltinScalarFunction::ConcatWithSeparator => Ok(Utf8), + BuiltinScalarFunction::DatePart => Ok(Float64), + BuiltinScalarFunction::DateTrunc | BuiltinScalarFunction::DateBin => { + match input_expr_types[1] { + Timestamp(Nanosecond, _) | Utf8 => Ok(Timestamp(Nanosecond, None)), + Timestamp(Microsecond, _) => Ok(Timestamp(Microsecond, None)), + Timestamp(Millisecond, _) => Ok(Timestamp(Millisecond, None)), + Timestamp(Second, _) => Ok(Timestamp(Second, None)), + _ => Err(DataFusionError::Internal(format!( + "The {self} function can only accept timestamp as the second arg." + ))), + } + } + BuiltinScalarFunction::InitCap => { + utf8_to_str_type(&input_expr_types[0], "initcap") + } + BuiltinScalarFunction::Left => utf8_to_str_type(&input_expr_types[0], "left"), + BuiltinScalarFunction::Lower => { + utf8_to_str_type(&input_expr_types[0], "lower") + } + BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"), + BuiltinScalarFunction::Ltrim => { + utf8_to_str_type(&input_expr_types[0], "ltrim") + } + BuiltinScalarFunction::MD5 => utf8_to_str_type(&input_expr_types[0], "md5"), + BuiltinScalarFunction::NullIf => { + // NULLIF has two args and they might get coerced, get a preview of this + let coerced_types = data_types(input_expr_types, &self.signature()); + coerced_types.map(|typs| typs[0].clone()) + } + BuiltinScalarFunction::OctetLength => { + utf8_to_int_type(&input_expr_types[0], "octet_length") + } + BuiltinScalarFunction::Pi => Ok(Float64), + BuiltinScalarFunction::Random => Ok(Float64), + BuiltinScalarFunction::Uuid => Ok(Utf8), + BuiltinScalarFunction::RegexpReplace => { + utf8_to_str_type(&input_expr_types[0], "regex_replace") + } + BuiltinScalarFunction::Repeat => { + utf8_to_str_type(&input_expr_types[0], "repeat") + } + BuiltinScalarFunction::Replace => { + utf8_to_str_type(&input_expr_types[0], "replace") + } + BuiltinScalarFunction::Reverse => { + utf8_to_str_type(&input_expr_types[0], "reverse") + } + BuiltinScalarFunction::Right => { + utf8_to_str_type(&input_expr_types[0], "right") + } + BuiltinScalarFunction::Rpad => utf8_to_str_type(&input_expr_types[0], "rpad"), + BuiltinScalarFunction::Rtrim => { + utf8_to_str_type(&input_expr_types[0], "rtrimp") + } + BuiltinScalarFunction::SHA224 => { + utf8_or_binary_to_binary_type(&input_expr_types[0], "sha224") + } + BuiltinScalarFunction::SHA256 => { + utf8_or_binary_to_binary_type(&input_expr_types[0], "sha256") + } + BuiltinScalarFunction::SHA384 => { + utf8_or_binary_to_binary_type(&input_expr_types[0], "sha384") + } + BuiltinScalarFunction::SHA512 => { + utf8_or_binary_to_binary_type(&input_expr_types[0], "sha512") + } + BuiltinScalarFunction::Digest => { + utf8_or_binary_to_binary_type(&input_expr_types[0], "digest") + } + BuiltinScalarFunction::SplitPart => { + utf8_to_str_type(&input_expr_types[0], "split_part") + } + BuiltinScalarFunction::StartsWith => Ok(Boolean), + BuiltinScalarFunction::Strpos => { + utf8_to_int_type(&input_expr_types[0], "strpos") + } + BuiltinScalarFunction::Substr => { + utf8_to_str_type(&input_expr_types[0], "substr") + } + BuiltinScalarFunction::ToHex => Ok(match input_expr_types[0] { + Int8 | Int16 | Int32 | Int64 => Utf8, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The to_hex function can only accept integers.".to_string(), + )); + } + }), + BuiltinScalarFunction::ToTimestamp => Ok(Timestamp(Nanosecond, None)), + BuiltinScalarFunction::ToTimestampMillis => Ok(Timestamp(Millisecond, None)), + BuiltinScalarFunction::ToTimestampMicros => Ok(Timestamp(Microsecond, None)), + BuiltinScalarFunction::ToTimestampSeconds => Ok(Timestamp(Second, None)), + BuiltinScalarFunction::FromUnixtime => Ok(Timestamp(Second, None)), + BuiltinScalarFunction::Now => { + Ok(Timestamp(Nanosecond, Some("+00:00".into()))) + } + BuiltinScalarFunction::CurrentDate => Ok(Date32), + BuiltinScalarFunction::CurrentTime => Ok(Time64(Nanosecond)), + BuiltinScalarFunction::Translate => { + utf8_to_str_type(&input_expr_types[0], "translate") + } + BuiltinScalarFunction::Trim => utf8_to_str_type(&input_expr_types[0], "trim"), + BuiltinScalarFunction::Upper => { + utf8_to_str_type(&input_expr_types[0], "upper") + } + BuiltinScalarFunction::RegexpMatch => Ok(match input_expr_types[0] { + LargeUtf8 => List(Arc::new(Field::new("item", LargeUtf8, true))), + Utf8 => List(Arc::new(Field::new("item", Utf8, true))), + Null => Null, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal( + "The regexp_extract function can only accept strings." + .to_string(), + )); + } + }), + + BuiltinScalarFunction::Factorial + | BuiltinScalarFunction::Gcd + | BuiltinScalarFunction::Lcm => Ok(Int64), + + BuiltinScalarFunction::Power => match &input_expr_types[0] { + Int64 => Ok(Int64), + _ => Ok(Float64), + }, + + BuiltinScalarFunction::Struct => { + let return_fields = input_expr_types + .iter() + .enumerate() + .map(|(pos, dt)| Field::new(format!("c{pos}"), dt.clone(), true)) + .collect::>(); + Ok(Struct(Fields::from(return_fields))) + } + + BuiltinScalarFunction::Atan2 => match &input_expr_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + }, + + BuiltinScalarFunction::Log => match &input_expr_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + }, + + BuiltinScalarFunction::ArrowTypeof => Ok(Utf8), + + BuiltinScalarFunction::Abs + | BuiltinScalarFunction::Acos + | BuiltinScalarFunction::Asin + | BuiltinScalarFunction::Atan + | BuiltinScalarFunction::Acosh + | BuiltinScalarFunction::Asinh + | BuiltinScalarFunction::Atanh + | BuiltinScalarFunction::Ceil + | BuiltinScalarFunction::Cos + | BuiltinScalarFunction::Cosh + | BuiltinScalarFunction::Degrees + | BuiltinScalarFunction::Exp + | BuiltinScalarFunction::Floor + | BuiltinScalarFunction::Ln + | BuiltinScalarFunction::Log10 + | BuiltinScalarFunction::Log2 + | BuiltinScalarFunction::Radians + | BuiltinScalarFunction::Round + | BuiltinScalarFunction::Signum + | BuiltinScalarFunction::Sin + | BuiltinScalarFunction::Sinh + | BuiltinScalarFunction::Sqrt + | BuiltinScalarFunction::Cbrt + | BuiltinScalarFunction::Tan + | BuiltinScalarFunction::Tanh + | BuiltinScalarFunction::Trunc => match input_expr_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + }, + } + } + + /// Return the argument [`Signature`] supported by this function + pub fn signature(&self) -> Signature { + use DataType::*; + use IntervalUnit::*; + use TimeUnit::*; + use TypeSignature::*; + // note: the physical expression must accept the type returned by this function or the execution panics. + + // for now, the list is small, as we do not have many built-in functions. + match self { + BuiltinScalarFunction::ArrayAppend => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayConcat => { + Signature::variadic_any(self.volatility()) + } + BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()), + BuiltinScalarFunction::ArrayFill => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayLength => { + Signature::variadic_any(self.volatility()) + } + BuiltinScalarFunction::ArrayNdims => Signature::any(1, self.volatility()), + BuiltinScalarFunction::ArrayPosition => { + Signature::variadic_any(self.volatility()) + } + BuiltinScalarFunction::ArrayPositions => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayPrepend => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayRemove => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayReplace => { + Signature::variadic_any(self.volatility()) + } + BuiltinScalarFunction::ArrayToString => { + Signature::variadic_any(self.volatility()) + } + BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()), + BuiltinScalarFunction::MakeArray => { + Signature::variadic_any(self.volatility()) + } + BuiltinScalarFunction::TrimArray => Signature::any(2, self.volatility()), + BuiltinScalarFunction::Struct => Signature::variadic( + struct_expressions::SUPPORTED_STRUCT_TYPES.to_vec(), + self.volatility(), + ), + BuiltinScalarFunction::Concat + | BuiltinScalarFunction::ConcatWithSeparator => { + Signature::variadic(vec![Utf8], self.volatility()) + } + BuiltinScalarFunction::Coalesce => Signature::variadic( + conditional_expressions::SUPPORTED_COALESCE_TYPES.to_vec(), + self.volatility(), + ), + BuiltinScalarFunction::SHA224 + | BuiltinScalarFunction::SHA256 + | BuiltinScalarFunction::SHA384 + | BuiltinScalarFunction::SHA512 + | BuiltinScalarFunction::MD5 => Signature::uniform( + 1, + vec![Utf8, LargeUtf8, Binary, LargeBinary], + self.volatility(), + ), + BuiltinScalarFunction::Ascii + | BuiltinScalarFunction::BitLength + | BuiltinScalarFunction::CharacterLength + | BuiltinScalarFunction::InitCap + | BuiltinScalarFunction::Lower + | BuiltinScalarFunction::OctetLength + | BuiltinScalarFunction::Reverse + | BuiltinScalarFunction::Upper => { + Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility()) + } + BuiltinScalarFunction::Btrim + | BuiltinScalarFunction::Ltrim + | BuiltinScalarFunction::Rtrim + | BuiltinScalarFunction::Trim => Signature::one_of( + vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], + self.volatility(), + ), + BuiltinScalarFunction::Chr | BuiltinScalarFunction::ToHex => { + Signature::uniform(1, vec![Int64], self.volatility()) + } + BuiltinScalarFunction::Lpad | BuiltinScalarFunction::Rpad => { + Signature::one_of( + vec![ + Exact(vec![Utf8, Int64]), + Exact(vec![LargeUtf8, Int64]), + Exact(vec![Utf8, Int64, Utf8]), + Exact(vec![LargeUtf8, Int64, Utf8]), + Exact(vec![Utf8, Int64, LargeUtf8]), + Exact(vec![LargeUtf8, Int64, LargeUtf8]), + ], + self.volatility(), + ) + } + BuiltinScalarFunction::Left + | BuiltinScalarFunction::Repeat + | BuiltinScalarFunction::Right => Signature::one_of( + vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], + self.volatility(), + ), + BuiltinScalarFunction::ToTimestamp => Signature::uniform( + 1, + vec![ + Int64, + Timestamp(Nanosecond, None), + Timestamp(Microsecond, None), + Timestamp(Millisecond, None), + Timestamp(Second, None), + Utf8, + ], + self.volatility(), + ), + BuiltinScalarFunction::ToTimestampMillis => Signature::uniform( + 1, + vec![ + Int64, + Timestamp(Nanosecond, None), + Timestamp(Microsecond, None), + Timestamp(Millisecond, None), + Timestamp(Second, None), + Utf8, + ], + self.volatility(), + ), + BuiltinScalarFunction::ToTimestampMicros => Signature::uniform( + 1, + vec![ + Int64, + Timestamp(Nanosecond, None), + Timestamp(Microsecond, None), + Timestamp(Millisecond, None), + Timestamp(Second, None), + Utf8, + ], + self.volatility(), + ), + BuiltinScalarFunction::ToTimestampSeconds => Signature::uniform( + 1, + vec![ + Int64, + Timestamp(Nanosecond, None), + Timestamp(Microsecond, None), + Timestamp(Millisecond, None), + Timestamp(Second, None), + Utf8, + ], + self.volatility(), + ), + BuiltinScalarFunction::FromUnixtime => { + Signature::uniform(1, vec![Int64], self.volatility()) + } + BuiltinScalarFunction::Digest => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![Binary, Utf8]), + Exact(vec![LargeBinary, Utf8]), + ], + self.volatility(), + ), + BuiltinScalarFunction::DateTrunc => Signature::exact( + vec![Utf8, Timestamp(Nanosecond, None)], + self.volatility(), + ), + BuiltinScalarFunction::DateBin => { + let base_sig = |array_type: TimeUnit| { + vec![ + Exact(vec![ + Interval(MonthDayNano), + Timestamp(array_type.clone(), None), + Timestamp(Nanosecond, None), + ]), + Exact(vec![ + Interval(DayTime), + Timestamp(array_type.clone(), None), + Timestamp(Nanosecond, None), + ]), + Exact(vec![ + Interval(MonthDayNano), + Timestamp(array_type.clone(), None), + ]), + Exact(vec![Interval(DayTime), Timestamp(array_type, None)]), + ] + }; + + let full_sig = [Nanosecond, Microsecond, Millisecond, Second] + .into_iter() + .map(base_sig) + .collect::>() + .concat(); + + Signature::one_of(full_sig, self.volatility()) + } + BuiltinScalarFunction::DatePart => Signature::one_of( + vec![ + Exact(vec![Utf8, Date32]), + Exact(vec![Utf8, Date64]), + Exact(vec![Utf8, Timestamp(Second, None)]), + Exact(vec![Utf8, Timestamp(Microsecond, None)]), + Exact(vec![Utf8, Timestamp(Millisecond, None)]), + Exact(vec![Utf8, Timestamp(Nanosecond, None)]), + Exact(vec![Utf8, Timestamp(Nanosecond, Some("+00:00".into()))]), + ], + self.volatility(), + ), + BuiltinScalarFunction::SplitPart => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, Utf8, Int64]), + Exact(vec![Utf8, LargeUtf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + ], + self.volatility(), + ), + + BuiltinScalarFunction::Strpos | BuiltinScalarFunction::StartsWith => { + Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8, LargeUtf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), + ], + self.volatility(), + ) + } + + BuiltinScalarFunction::Substr => Signature::one_of( + vec![ + Exact(vec![Utf8, Int64]), + Exact(vec![LargeUtf8, Int64]), + Exact(vec![Utf8, Int64, Int64]), + Exact(vec![LargeUtf8, Int64, Int64]), + ], + self.volatility(), + ), + + BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => { + Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], self.volatility()) + } + BuiltinScalarFunction::RegexpReplace => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Utf8]), + Exact(vec![Utf8, Utf8, Utf8, Utf8]), + ], + self.volatility(), + ), + + BuiltinScalarFunction::NullIf => { + Signature::uniform(2, SUPPORTED_NULLIF_TYPES.to_vec(), self.volatility()) + } + BuiltinScalarFunction::RegexpMatch => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![Utf8, Utf8, Utf8]), + Exact(vec![LargeUtf8, Utf8, Utf8]), + ], + self.volatility(), + ), + BuiltinScalarFunction::Pi => Signature::exact(vec![], self.volatility()), + BuiltinScalarFunction::Random => Signature::exact(vec![], self.volatility()), + BuiltinScalarFunction::Uuid => Signature::exact(vec![], self.volatility()), + BuiltinScalarFunction::Power => Signature::one_of( + vec![Exact(vec![Int64, Int64]), Exact(vec![Float64, Float64])], + self.volatility(), + ), + BuiltinScalarFunction::Round => Signature::one_of( + vec![ + Exact(vec![Float64, Int64]), + Exact(vec![Float32, Int64]), + Exact(vec![Float64]), + Exact(vec![Float32]), + ], + self.volatility(), + ), + BuiltinScalarFunction::Atan2 => Signature::one_of( + vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], + self.volatility(), + ), + BuiltinScalarFunction::Log => Signature::one_of( + vec![ + Exact(vec![Float32]), + Exact(vec![Float64]), + Exact(vec![Float32, Float32]), + Exact(vec![Float64, Float64]), + ], + self.volatility(), + ), + BuiltinScalarFunction::Factorial => { + Signature::uniform(1, vec![Int64], self.volatility()) + } + BuiltinScalarFunction::Gcd | BuiltinScalarFunction::Lcm => { + Signature::uniform(2, vec![Int64], self.volatility()) + } + BuiltinScalarFunction::ArrowTypeof => Signature::any(1, self.volatility()), + BuiltinScalarFunction::Abs + | BuiltinScalarFunction::Acos + | BuiltinScalarFunction::Asin + | BuiltinScalarFunction::Atan + | BuiltinScalarFunction::Acosh + | BuiltinScalarFunction::Asinh + | BuiltinScalarFunction::Atanh + | BuiltinScalarFunction::Cbrt + | BuiltinScalarFunction::Ceil + | BuiltinScalarFunction::Cos + | BuiltinScalarFunction::Cosh + | BuiltinScalarFunction::Degrees + | BuiltinScalarFunction::Exp + | BuiltinScalarFunction::Floor + | BuiltinScalarFunction::Ln + | BuiltinScalarFunction::Log10 + | BuiltinScalarFunction::Log2 + | BuiltinScalarFunction::Radians + | BuiltinScalarFunction::Signum + | BuiltinScalarFunction::Sin + | BuiltinScalarFunction::Sinh + | BuiltinScalarFunction::Sqrt + | BuiltinScalarFunction::Tan + | BuiltinScalarFunction::Tanh + | BuiltinScalarFunction::Trunc => { + // math expressions expect 1 argument of type f64 or f32 + // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we + // return the best approximation for it (in f64). + // We accept f32 because in this case it is clear that the best approximation + // will be as good as the number of digits in the number + Signature::uniform(1, vec![Float64, Float32], self.volatility()) + } + BuiltinScalarFunction::Now + | BuiltinScalarFunction::CurrentDate + | BuiltinScalarFunction::CurrentTime => { + Signature::uniform(0, vec![], self.volatility()) + } + } + } } fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { @@ -526,6 +1198,44 @@ impl FromStr for BuiltinScalarFunction { } } +macro_rules! make_utf8_to_return_type { + ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { + fn $FUNC(arg_type: &DataType, name: &str) -> Result { + Ok(match arg_type { + DataType::LargeUtf8 => $largeUtf8Type, + DataType::Utf8 => $utf8Type, + DataType::Null => DataType::Null, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal(format!( + "The {:?} function can only accept strings.", + name + ))); + } + }) + } + }; +} + +make_utf8_to_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); +make_utf8_to_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); + +fn utf8_or_binary_to_binary_type(arg_type: &DataType, name: &str) -> Result { + Ok(match arg_type { + DataType::LargeUtf8 + | DataType::Utf8 + | DataType::Binary + | DataType::LargeBinary => DataType::Binary, + DataType::Null => DataType::Null, + _ => { + // this error is internal as `data_types` should have captured this. + return Err(DataFusionError::Internal(format!( + "The {name:?} function can only accept strings or binary arrays." + ))); + } + }) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 3c68c4acd7dd..52ad65773c1a 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -22,9 +22,7 @@ use crate::expr::{ }; use crate::field_util::get_indexed_field; use crate::type_coercion::binary::get_result_type; -use crate::{ - aggregate_function, function, window_function, LogicalPlan, Projection, Subquery, -}; +use crate::{aggregate_function, window_function, LogicalPlan, Projection, Subquery}; use arrow::compute::can_cast_types; use arrow::datatypes::DataType; use datafusion_common::{Column, DFField, DFSchema, DataFusionError, ExprSchema, Result}; @@ -87,7 +85,7 @@ impl ExprSchemable for Expr { .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - function::return_type(fun, &data_types) + fun.return_type(&data_types) } Expr::WindowFunction(WindowFunction { fun, args, .. }) => { let data_types = args diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index f47b94322d94..bd242c493e43 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -17,17 +17,13 @@ //! Function module contains typing and signature for built-in and user defined functions. -use crate::function_err::generate_signature_error_msg; -use crate::nullif::SUPPORTED_NULLIF_TYPES; -use crate::type_coercion::functions::data_types; -use crate::ColumnarValue; -use crate::{ - conditional_expressions, struct_expressions, Accumulator, BuiltinScalarFunction, - Signature, TypeSignature, -}; -use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit}; -use datafusion_common::{DataFusionError, Result}; +use crate::{Accumulator, BuiltinScalarFunction, Signature}; +use crate::{AggregateFunction, BuiltInWindowFunction, ColumnarValue}; +use arrow::datatypes::DataType; +use datafusion_common::utils::datafusion_strsim; +use datafusion_common::Result; use std::sync::Arc; +use strum::IntoEnumIterator; /// Scalar function /// @@ -54,646 +50,53 @@ pub type AccumulatorFunctionImplementation = pub type StateTypeFunction = Arc Result>> + Send + Sync>; -macro_rules! make_utf8_to_return_type { - ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { - fn $FUNC(arg_type: &DataType, name: &str) -> Result { - Ok(match arg_type { - DataType::LargeUtf8 => $largeUtf8Type, - DataType::Utf8 => $utf8Type, - DataType::Null => DataType::Null, - _ => { - // this error is internal as `data_types` should have captured this. - return Err(DataFusionError::Internal(format!( - "The {:?} function can only accept strings.", - name - ))); - } - }) - } - }; -} - -make_utf8_to_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); -make_utf8_to_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); - -fn utf8_or_binary_to_binary_type(arg_type: &DataType, name: &str) -> Result { - Ok(match arg_type { - DataType::LargeUtf8 - | DataType::Utf8 - | DataType::Binary - | DataType::LargeBinary => DataType::Binary, - DataType::Null => DataType::Null, - _ => { - // this error is internal as `data_types` should have captured this. - return Err(DataFusionError::Internal(format!( - "The {name:?} function can only accept strings or binary arrays." - ))); - } - }) -} - /// Returns the datatype of the scalar function +#[deprecated( + since = "27.0.0", + note = "please use `BuiltinScalarFunction::return_type` instead" +)] pub fn return_type( fun: &BuiltinScalarFunction, input_expr_types: &[DataType], ) -> Result { - use DataType::*; - use TimeUnit::*; - - // Note that this function *must* return the same type that the respective physical expression returns - // or the execution panics. - - if input_expr_types.is_empty() && !fun.supports_zero_argument() { - return Err(DataFusionError::Plan(generate_signature_error_msg( - fun, - input_expr_types, - ))); - } - - // verify that this is a valid set of data types for this function - data_types(input_expr_types, &signature(fun)).map_err(|_| { - DataFusionError::Plan(generate_signature_error_msg(fun, input_expr_types)) - })?; - - // the return type of the built in function. - // Some built-in functions' return type depends on the incoming type. - match fun { - BuiltinScalarFunction::ArrayAppend => match &input_expr_types[0] { - List(field) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only accept list as the first argument" - ))), - }, - BuiltinScalarFunction::ArrayConcat => match &input_expr_types[0] { - List(field) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only accept fixed size list as the args." - ))), - }, - BuiltinScalarFunction::ArrayDims => Ok(UInt8), - BuiltinScalarFunction::ArrayFill => Ok(List(Arc::new(Field::new( - "item", - input_expr_types[0].clone(), - true, - )))), - BuiltinScalarFunction::ArrayLength => Ok(UInt8), - BuiltinScalarFunction::ArrayNdims => Ok(UInt8), - BuiltinScalarFunction::ArrayPosition => Ok(UInt8), - BuiltinScalarFunction::ArrayPositions => Ok(UInt8), - BuiltinScalarFunction::ArrayPrepend => match &input_expr_types[1] { - List(field) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only accept list as the first argument" - ))), - }, - BuiltinScalarFunction::ArrayRemove => match &input_expr_types[0] { - List(field) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only accept list as the first argument" - ))), - }, - BuiltinScalarFunction::ArrayReplace => match &input_expr_types[0] { - List(field) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only accept list as the first argument" - ))), - }, - BuiltinScalarFunction::ArrayToString => match &input_expr_types[0] { - List(field) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only accept list as the first argument" - ))), - }, - BuiltinScalarFunction::Cardinality => Ok(UInt64), - BuiltinScalarFunction::MakeArray => Ok(List(Arc::new(Field::new( - "item", - input_expr_types[0].clone(), - true, - )))), - BuiltinScalarFunction::TrimArray => match &input_expr_types[0] { - List(field) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only accept list as the first argument" - ))), - }, - BuiltinScalarFunction::Ascii => Ok(Int32), - BuiltinScalarFunction::BitLength => { - utf8_to_int_type(&input_expr_types[0], "bit_length") - } - BuiltinScalarFunction::Btrim => utf8_to_str_type(&input_expr_types[0], "btrim"), - BuiltinScalarFunction::CharacterLength => { - utf8_to_int_type(&input_expr_types[0], "character_length") - } - BuiltinScalarFunction::Chr => Ok(Utf8), - BuiltinScalarFunction::Coalesce => { - // COALESCE has multiple args and they might get coerced, get a preview of this - let coerced_types = data_types(input_expr_types, &signature(fun)); - coerced_types.map(|types| types[0].clone()) - } - BuiltinScalarFunction::Concat => Ok(Utf8), - BuiltinScalarFunction::ConcatWithSeparator => Ok(Utf8), - BuiltinScalarFunction::DatePart => Ok(Float64), - BuiltinScalarFunction::DateTrunc | BuiltinScalarFunction::DateBin => { - match input_expr_types[1] { - Timestamp(Nanosecond, _) | Utf8 => Ok(Timestamp(Nanosecond, None)), - Timestamp(Microsecond, _) => Ok(Timestamp(Microsecond, None)), - Timestamp(Millisecond, _) => Ok(Timestamp(Millisecond, None)), - Timestamp(Second, _) => Ok(Timestamp(Second, None)), - _ => Err(DataFusionError::Internal(format!( - "The {fun} function can only accept timestamp as the second arg." - ))), - } - } - BuiltinScalarFunction::InitCap => { - utf8_to_str_type(&input_expr_types[0], "initcap") - } - BuiltinScalarFunction::Left => utf8_to_str_type(&input_expr_types[0], "left"), - BuiltinScalarFunction::Lower => utf8_to_str_type(&input_expr_types[0], "lower"), - BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"), - BuiltinScalarFunction::Ltrim => utf8_to_str_type(&input_expr_types[0], "ltrim"), - BuiltinScalarFunction::MD5 => utf8_to_str_type(&input_expr_types[0], "md5"), - BuiltinScalarFunction::NullIf => { - // NULLIF has two args and they might get coerced, get a preview of this - let coerced_types = data_types(input_expr_types, &signature(fun)); - coerced_types.map(|typs| typs[0].clone()) - } - BuiltinScalarFunction::OctetLength => { - utf8_to_int_type(&input_expr_types[0], "octet_length") - } - BuiltinScalarFunction::Pi => Ok(Float64), - BuiltinScalarFunction::Random => Ok(Float64), - BuiltinScalarFunction::Uuid => Ok(Utf8), - BuiltinScalarFunction::RegexpReplace => { - utf8_to_str_type(&input_expr_types[0], "regex_replace") - } - BuiltinScalarFunction::Repeat => utf8_to_str_type(&input_expr_types[0], "repeat"), - BuiltinScalarFunction::Replace => { - utf8_to_str_type(&input_expr_types[0], "replace") - } - BuiltinScalarFunction::Reverse => { - utf8_to_str_type(&input_expr_types[0], "reverse") - } - BuiltinScalarFunction::Right => utf8_to_str_type(&input_expr_types[0], "right"), - BuiltinScalarFunction::Rpad => utf8_to_str_type(&input_expr_types[0], "rpad"), - BuiltinScalarFunction::Rtrim => utf8_to_str_type(&input_expr_types[0], "rtrimp"), - BuiltinScalarFunction::SHA224 => { - utf8_or_binary_to_binary_type(&input_expr_types[0], "sha224") - } - BuiltinScalarFunction::SHA256 => { - utf8_or_binary_to_binary_type(&input_expr_types[0], "sha256") - } - BuiltinScalarFunction::SHA384 => { - utf8_or_binary_to_binary_type(&input_expr_types[0], "sha384") - } - BuiltinScalarFunction::SHA512 => { - utf8_or_binary_to_binary_type(&input_expr_types[0], "sha512") - } - BuiltinScalarFunction::Digest => { - utf8_or_binary_to_binary_type(&input_expr_types[0], "digest") - } - BuiltinScalarFunction::SplitPart => { - utf8_to_str_type(&input_expr_types[0], "split_part") - } - BuiltinScalarFunction::StartsWith => Ok(Boolean), - BuiltinScalarFunction::Strpos => utf8_to_int_type(&input_expr_types[0], "strpos"), - BuiltinScalarFunction::Substr => utf8_to_str_type(&input_expr_types[0], "substr"), - BuiltinScalarFunction::ToHex => Ok(match input_expr_types[0] { - Int8 | Int16 | Int32 | Int64 => Utf8, - _ => { - // this error is internal as `data_types` should have captured this. - return Err(DataFusionError::Internal( - "The to_hex function can only accept integers.".to_string(), - )); - } - }), - BuiltinScalarFunction::ToTimestamp => Ok(Timestamp(Nanosecond, None)), - BuiltinScalarFunction::ToTimestampMillis => Ok(Timestamp(Millisecond, None)), - BuiltinScalarFunction::ToTimestampMicros => Ok(Timestamp(Microsecond, None)), - BuiltinScalarFunction::ToTimestampSeconds => Ok(Timestamp(Second, None)), - BuiltinScalarFunction::FromUnixtime => Ok(Timestamp(Second, None)), - BuiltinScalarFunction::Now => Ok(Timestamp(Nanosecond, Some("+00:00".into()))), - BuiltinScalarFunction::CurrentDate => Ok(Date32), - BuiltinScalarFunction::CurrentTime => Ok(Time64(Nanosecond)), - BuiltinScalarFunction::Translate => { - utf8_to_str_type(&input_expr_types[0], "translate") - } - BuiltinScalarFunction::Trim => utf8_to_str_type(&input_expr_types[0], "trim"), - BuiltinScalarFunction::Upper => utf8_to_str_type(&input_expr_types[0], "upper"), - BuiltinScalarFunction::RegexpMatch => Ok(match input_expr_types[0] { - LargeUtf8 => List(Arc::new(Field::new("item", LargeUtf8, true))), - Utf8 => List(Arc::new(Field::new("item", Utf8, true))), - Null => Null, - _ => { - // this error is internal as `data_types` should have captured this. - return Err(DataFusionError::Internal( - "The regexp_extract function can only accept strings.".to_string(), - )); - } - }), - - BuiltinScalarFunction::Factorial - | BuiltinScalarFunction::Gcd - | BuiltinScalarFunction::Lcm => Ok(Int64), - - BuiltinScalarFunction::Power => match &input_expr_types[0] { - Int64 => Ok(Int64), - _ => Ok(Float64), - }, - - BuiltinScalarFunction::Struct => { - let return_fields = input_expr_types - .iter() - .enumerate() - .map(|(pos, dt)| Field::new(format!("c{pos}"), dt.clone(), true)) - .collect::>(); - Ok(Struct(Fields::from(return_fields))) - } - - BuiltinScalarFunction::Atan2 => match &input_expr_types[0] { - Float32 => Ok(Float32), - _ => Ok(Float64), - }, - - BuiltinScalarFunction::Log => match &input_expr_types[0] { - Float32 => Ok(Float32), - _ => Ok(Float64), - }, - - BuiltinScalarFunction::ArrowTypeof => Ok(Utf8), - - BuiltinScalarFunction::Abs - | BuiltinScalarFunction::Acos - | BuiltinScalarFunction::Asin - | BuiltinScalarFunction::Atan - | BuiltinScalarFunction::Acosh - | BuiltinScalarFunction::Asinh - | BuiltinScalarFunction::Atanh - | BuiltinScalarFunction::Ceil - | BuiltinScalarFunction::Cos - | BuiltinScalarFunction::Cosh - | BuiltinScalarFunction::Degrees - | BuiltinScalarFunction::Exp - | BuiltinScalarFunction::Floor - | BuiltinScalarFunction::Ln - | BuiltinScalarFunction::Log10 - | BuiltinScalarFunction::Log2 - | BuiltinScalarFunction::Radians - | BuiltinScalarFunction::Round - | BuiltinScalarFunction::Signum - | BuiltinScalarFunction::Sin - | BuiltinScalarFunction::Sinh - | BuiltinScalarFunction::Sqrt - | BuiltinScalarFunction::Cbrt - | BuiltinScalarFunction::Tan - | BuiltinScalarFunction::Tanh - | BuiltinScalarFunction::Trunc => match input_expr_types[0] { - Float32 => Ok(Float32), - _ => Ok(Float64), - }, - } + fun.return_type(input_expr_types) } /// Return the [`Signature`] supported by the function `fun`. +#[deprecated( + since = "27.0.0", + note = "please use `BuiltinScalarFunction::signature` instead" +)] pub fn signature(fun: &BuiltinScalarFunction) -> Signature { - use DataType::*; - use IntervalUnit::*; - use TimeUnit::*; - use TypeSignature::*; - // note: the physical expression must accept the type returned by this function or the execution panics. - - // for now, the list is small, as we do not have many built-in functions. - match fun { - BuiltinScalarFunction::ArrayAppend => Signature::any(2, fun.volatility()), - BuiltinScalarFunction::ArrayConcat => Signature::variadic_any(fun.volatility()), - BuiltinScalarFunction::ArrayDims => Signature::any(1, fun.volatility()), - BuiltinScalarFunction::ArrayFill => Signature::any(2, fun.volatility()), - BuiltinScalarFunction::ArrayLength => Signature::variadic_any(fun.volatility()), - BuiltinScalarFunction::ArrayNdims => Signature::any(1, fun.volatility()), - BuiltinScalarFunction::ArrayPosition => Signature::variadic_any(fun.volatility()), - BuiltinScalarFunction::ArrayPositions => Signature::any(2, fun.volatility()), - BuiltinScalarFunction::ArrayPrepend => Signature::any(2, fun.volatility()), - BuiltinScalarFunction::ArrayRemove => Signature::any(2, fun.volatility()), - BuiltinScalarFunction::ArrayReplace => Signature::variadic_any(fun.volatility()), - BuiltinScalarFunction::ArrayToString => Signature::variadic_any(fun.volatility()), - BuiltinScalarFunction::Cardinality => Signature::any(1, fun.volatility()), - BuiltinScalarFunction::MakeArray => Signature::variadic_any(fun.volatility()), - BuiltinScalarFunction::TrimArray => Signature::any(2, fun.volatility()), - BuiltinScalarFunction::Struct => Signature::variadic( - struct_expressions::SUPPORTED_STRUCT_TYPES.to_vec(), - fun.volatility(), - ), - BuiltinScalarFunction::Concat | BuiltinScalarFunction::ConcatWithSeparator => { - Signature::variadic(vec![Utf8], fun.volatility()) - } - BuiltinScalarFunction::Coalesce => Signature::variadic( - conditional_expressions::SUPPORTED_COALESCE_TYPES.to_vec(), - fun.volatility(), - ), - BuiltinScalarFunction::SHA224 - | BuiltinScalarFunction::SHA256 - | BuiltinScalarFunction::SHA384 - | BuiltinScalarFunction::SHA512 - | BuiltinScalarFunction::MD5 => Signature::uniform( - 1, - vec![Utf8, LargeUtf8, Binary, LargeBinary], - fun.volatility(), - ), - BuiltinScalarFunction::Ascii - | BuiltinScalarFunction::BitLength - | BuiltinScalarFunction::CharacterLength - | BuiltinScalarFunction::InitCap - | BuiltinScalarFunction::Lower - | BuiltinScalarFunction::OctetLength - | BuiltinScalarFunction::Reverse - | BuiltinScalarFunction::Upper => { - Signature::uniform(1, vec![Utf8, LargeUtf8], fun.volatility()) - } - BuiltinScalarFunction::Btrim - | BuiltinScalarFunction::Ltrim - | BuiltinScalarFunction::Rtrim - | BuiltinScalarFunction::Trim => Signature::one_of( - vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], - fun.volatility(), - ), - BuiltinScalarFunction::Chr | BuiltinScalarFunction::ToHex => { - Signature::uniform(1, vec![Int64], fun.volatility()) - } - BuiltinScalarFunction::Lpad | BuiltinScalarFunction::Rpad => Signature::one_of( - vec![ - Exact(vec![Utf8, Int64]), - Exact(vec![LargeUtf8, Int64]), - Exact(vec![Utf8, Int64, Utf8]), - Exact(vec![LargeUtf8, Int64, Utf8]), - Exact(vec![Utf8, Int64, LargeUtf8]), - Exact(vec![LargeUtf8, Int64, LargeUtf8]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::Left - | BuiltinScalarFunction::Repeat - | BuiltinScalarFunction::Right => Signature::one_of( - vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], - fun.volatility(), - ), - BuiltinScalarFunction::ToTimestamp => Signature::uniform( - 1, - vec![ - Int64, - Timestamp(Nanosecond, None), - Timestamp(Microsecond, None), - Timestamp(Millisecond, None), - Timestamp(Second, None), - Utf8, - ], - fun.volatility(), - ), - BuiltinScalarFunction::ToTimestampMillis => Signature::uniform( - 1, - vec![ - Int64, - Timestamp(Nanosecond, None), - Timestamp(Microsecond, None), - Timestamp(Millisecond, None), - Timestamp(Second, None), - Utf8, - ], - fun.volatility(), - ), - BuiltinScalarFunction::ToTimestampMicros => Signature::uniform( - 1, - vec![ - Int64, - Timestamp(Nanosecond, None), - Timestamp(Microsecond, None), - Timestamp(Millisecond, None), - Timestamp(Second, None), - Utf8, - ], - fun.volatility(), - ), - BuiltinScalarFunction::ToTimestampSeconds => Signature::uniform( - 1, - vec![ - Int64, - Timestamp(Nanosecond, None), - Timestamp(Microsecond, None), - Timestamp(Millisecond, None), - Timestamp(Second, None), - Utf8, - ], - fun.volatility(), - ), - BuiltinScalarFunction::FromUnixtime => { - Signature::uniform(1, vec![Int64], fun.volatility()) - } - BuiltinScalarFunction::Digest => Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![Binary, Utf8]), - Exact(vec![LargeBinary, Utf8]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::DateTrunc => { - Signature::exact(vec![Utf8, Timestamp(Nanosecond, None)], fun.volatility()) - } - BuiltinScalarFunction::DateBin => { - let base_sig = |array_type: TimeUnit| { - vec![ - Exact(vec![ - Interval(MonthDayNano), - Timestamp(array_type.clone(), None), - Timestamp(Nanosecond, None), - ]), - Exact(vec![ - Interval(DayTime), - Timestamp(array_type.clone(), None), - Timestamp(Nanosecond, None), - ]), - Exact(vec![ - Interval(MonthDayNano), - Timestamp(array_type.clone(), None), - ]), - Exact(vec![Interval(DayTime), Timestamp(array_type, None)]), - ] - }; - - let full_sig = [Nanosecond, Microsecond, Millisecond, Second] - .into_iter() - .map(base_sig) - .collect::>() - .concat(); - - Signature::one_of(full_sig, fun.volatility()) - } - BuiltinScalarFunction::DatePart => Signature::one_of( - vec![ - Exact(vec![Utf8, Date32]), - Exact(vec![Utf8, Date64]), - Exact(vec![Utf8, Timestamp(Second, None)]), - Exact(vec![Utf8, Timestamp(Microsecond, None)]), - Exact(vec![Utf8, Timestamp(Millisecond, None)]), - Exact(vec![Utf8, Timestamp(Nanosecond, None)]), - Exact(vec![Utf8, Timestamp(Nanosecond, Some("+00:00".into()))]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::SplitPart => Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8, Int64]), - Exact(vec![LargeUtf8, Utf8, Int64]), - Exact(vec![Utf8, LargeUtf8, Int64]), - Exact(vec![LargeUtf8, LargeUtf8, Int64]), - ], - fun.volatility(), - ), - - BuiltinScalarFunction::Strpos | BuiltinScalarFunction::StartsWith => { - Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![Utf8, LargeUtf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![LargeUtf8, LargeUtf8]), - ], - fun.volatility(), - ) - } - - BuiltinScalarFunction::Substr => Signature::one_of( - vec![ - Exact(vec![Utf8, Int64]), - Exact(vec![LargeUtf8, Int64]), - Exact(vec![Utf8, Int64, Int64]), - Exact(vec![LargeUtf8, Int64, Int64]), - ], - fun.volatility(), - ), + fun.signature() +} - BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => { - Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], fun.volatility()) - } - BuiltinScalarFunction::RegexpReplace => Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8, Utf8]), - Exact(vec![Utf8, Utf8, Utf8, Utf8]), - ], - fun.volatility(), - ), +/// Suggest a valid function based on an invalid input function name +pub fn suggest_valid_function(input_function_name: &str, is_window_func: bool) -> String { + let valid_funcs = if is_window_func { + // All aggregate functions and builtin window functions + AggregateFunction::iter() + .map(|func| func.to_string()) + .chain(BuiltInWindowFunction::iter().map(|func| func.to_string())) + .collect() + } else { + // All scalar functions and aggregate functions + BuiltinScalarFunction::iter() + .map(|func| func.to_string()) + .chain(AggregateFunction::iter().map(|func| func.to_string())) + .collect() + }; + find_closest_match(valid_funcs, input_function_name) +} - BuiltinScalarFunction::NullIf => { - Signature::uniform(2, SUPPORTED_NULLIF_TYPES.to_vec(), fun.volatility()) - } - BuiltinScalarFunction::RegexpMatch => Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![Utf8, Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8, Utf8]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::Pi => Signature::exact(vec![], fun.volatility()), - BuiltinScalarFunction::Random => Signature::exact(vec![], fun.volatility()), - BuiltinScalarFunction::Uuid => Signature::exact(vec![], fun.volatility()), - BuiltinScalarFunction::Power => Signature::one_of( - vec![Exact(vec![Int64, Int64]), Exact(vec![Float64, Float64])], - fun.volatility(), - ), - BuiltinScalarFunction::Round => Signature::one_of( - vec![ - Exact(vec![Float64, Int64]), - Exact(vec![Float32, Int64]), - Exact(vec![Float64]), - Exact(vec![Float32]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::Atan2 => Signature::one_of( - vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], - fun.volatility(), - ), - BuiltinScalarFunction::Log => Signature::one_of( - vec![ - Exact(vec![Float32]), - Exact(vec![Float64]), - Exact(vec![Float32, Float32]), - Exact(vec![Float64, Float64]), - ], - fun.volatility(), - ), - BuiltinScalarFunction::Factorial => { - Signature::uniform(1, vec![Int64], fun.volatility()) - } - BuiltinScalarFunction::Gcd | BuiltinScalarFunction::Lcm => { - Signature::uniform(2, vec![Int64], fun.volatility()) - } - BuiltinScalarFunction::ArrowTypeof => Signature::any(1, fun.volatility()), - BuiltinScalarFunction::Abs - | BuiltinScalarFunction::Acos - | BuiltinScalarFunction::Asin - | BuiltinScalarFunction::Atan - | BuiltinScalarFunction::Acosh - | BuiltinScalarFunction::Asinh - | BuiltinScalarFunction::Atanh - | BuiltinScalarFunction::Cbrt - | BuiltinScalarFunction::Ceil - | BuiltinScalarFunction::Cos - | BuiltinScalarFunction::Cosh - | BuiltinScalarFunction::Degrees - | BuiltinScalarFunction::Exp - | BuiltinScalarFunction::Floor - | BuiltinScalarFunction::Ln - | BuiltinScalarFunction::Log10 - | BuiltinScalarFunction::Log2 - | BuiltinScalarFunction::Radians - | BuiltinScalarFunction::Signum - | BuiltinScalarFunction::Sin - | BuiltinScalarFunction::Sinh - | BuiltinScalarFunction::Sqrt - | BuiltinScalarFunction::Tan - | BuiltinScalarFunction::Tanh - | BuiltinScalarFunction::Trunc => { - // math expressions expect 1 argument of type f64 or f32 - // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we - // return the best approximation for it (in f64). - // We accept f32 because in this case it is clear that the best approximation - // will be as good as the number of digits in the number - Signature::uniform(1, vec![Float64, Float32], fun.volatility()) - } - BuiltinScalarFunction::Now - | BuiltinScalarFunction::CurrentDate - | BuiltinScalarFunction::CurrentTime => { - Signature::uniform(0, vec![], fun.volatility()) - } - } +/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve) +/// Input `candidates` must not be empty otherwise it will panic +fn find_closest_match(candidates: Vec, target: &str) -> String { + let target = target.to_lowercase(); + candidates + .into_iter() + .min_by_key(|candidate| { + datafusion_strsim::levenshtein(&candidate.to_lowercase(), &target) + }) + .expect("No candidates provided.") // Panic if `candidates` argument is empty } diff --git a/datafusion/expr/src/function_err.rs b/datafusion/expr/src/function_err.rs deleted file mode 100644 index 1635ac3b0c8c..000000000000 --- a/datafusion/expr/src/function_err.rs +++ /dev/null @@ -1,125 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Function_err module enhances frontend error messages for unresolved functions due to incorrect parameters, -//! by providing the correct function signatures. -//! -//! For example, a query like `select round(3.14, 1.1);` would yield: -//! ```text -//! Error during planning: No function matches 'round(Float64, Float64)'. You might need to add explicit type casts. -//! Candidate functions: -//! round(Float64, Int64) -//! round(Float32, Int64) -//! round(Float64) -//! round(Float32) -//! ``` - -use crate::function::signature; -use crate::{ - AggregateFunction, BuiltInWindowFunction, BuiltinScalarFunction, TypeSignature, -}; -use arrow::datatypes::DataType; -use datafusion_common::utils::datafusion_strsim; -use strum::IntoEnumIterator; - -impl TypeSignature { - fn to_string_repr(&self) -> Vec { - match self { - TypeSignature::Variadic(types) => { - vec![format!("{}, ..", join_types(types, "/"))] - } - TypeSignature::Uniform(arg_count, valid_types) => { - vec![std::iter::repeat(join_types(valid_types, "/")) - .take(*arg_count) - .collect::>() - .join(", ")] - } - TypeSignature::Exact(types) => { - vec![join_types(types, ", ")] - } - TypeSignature::Any(arg_count) => { - vec![std::iter::repeat("Any") - .take(*arg_count) - .collect::>() - .join(", ")] - } - TypeSignature::VariadicEqual => vec!["T, .., T".to_string()], - TypeSignature::VariadicAny => vec!["Any, .., Any".to_string()], - TypeSignature::OneOf(sigs) => { - sigs.iter().flat_map(|s| s.to_string_repr()).collect() - } - } - } -} - -/// Helper function to join types with specified delimiter. -fn join_types(types: &[T], delimiter: &str) -> String { - types - .iter() - .map(|t| t.to_string()) - .collect::>() - .join(delimiter) -} - -/// Creates a detailed error message for a function with wrong signature. -pub fn generate_signature_error_msg( - fun: &BuiltinScalarFunction, - input_expr_types: &[DataType], -) -> String { - let candidate_signatures = signature(fun) - .type_signature - .to_string_repr() - .iter() - .map(|args_str| format!("\t{fun}({args_str})")) - .collect::>() - .join("\n"); - - format!( - "No function matches the given name and argument types '{}({})'. You might need to add explicit type casts.\n\tCandidate functions:\n{}", - fun, join_types(input_expr_types, ", "), candidate_signatures - ) -} - -/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve) -/// Input `candidates` must not be empty otherwise it will panic -fn find_closest_match(candidates: Vec, target: &str) -> String { - let target = target.to_lowercase(); - candidates - .into_iter() - .min_by_key(|candidate| { - datafusion_strsim::levenshtein(&candidate.to_lowercase(), &target) - }) - .expect("No candidates provided.") // Panic if `candidates` argument is empty -} - -/// Suggest a valid function based on an invalid input function name -pub fn suggest_valid_function(input_function_name: &str, is_window_func: bool) -> String { - let valid_funcs = if is_window_func { - // All aggregate functions and builtin window functions - AggregateFunction::iter() - .map(|func| func.to_string()) - .chain(BuiltInWindowFunction::iter().map(|func| func.to_string())) - .collect() - } else { - // All scalar functions and aggregate functions - BuiltinScalarFunction::iter() - .map(|func| func.to_string()) - .chain(AggregateFunction::iter().map(|func| func.to_string())) - .collect() - }; - find_closest_match(valid_funcs, input_function_name) -} diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 5945480aba1d..1675afb9c98a 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -37,7 +37,6 @@ pub mod expr_rewriter; pub mod expr_schema; pub mod field_util; pub mod function; -pub mod function_err; mod literal; pub mod logical_plan; mod nullif; diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index a2caba4fb8bb..e4ffd74d8daa 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -60,6 +60,48 @@ pub enum TypeSignature { OneOf(Vec), } +impl TypeSignature { + pub(crate) fn to_string_repr(&self) -> Vec { + match self { + TypeSignature::Variadic(types) => { + vec![format!("{}, ..", Self::join_types(types, "/"))] + } + TypeSignature::Uniform(arg_count, valid_types) => { + vec![std::iter::repeat(Self::join_types(valid_types, "/")) + .take(*arg_count) + .collect::>() + .join(", ")] + } + TypeSignature::Exact(types) => { + vec![Self::join_types(types, ", ")] + } + TypeSignature::Any(arg_count) => { + vec![std::iter::repeat("Any") + .take(*arg_count) + .collect::>() + .join(", ")] + } + TypeSignature::VariadicEqual => vec!["T, .., T".to_string()], + TypeSignature::VariadicAny => vec!["Any, .., Any".to_string()], + TypeSignature::OneOf(sigs) => { + sigs.iter().flat_map(|s| s.to_string_repr()).collect() + } + } + } + + /// Helper function to join types with specified delimiter. + pub(crate) fn join_types( + types: &[T], + delimiter: &str, + ) -> String { + types + .iter() + .map(|t| t.to_string()) + .collect::>() + .join(delimiter) + } +} + /// The signature of a function defines the supported argument types /// and its volatility. #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 0d0061a5e435..3ee6a2401b02 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -40,8 +40,8 @@ use datafusion_expr::type_coercion::other::{ use datafusion_expr::type_coercion::{is_datetime, is_numeric, is_utf8_or_large_utf8}; use datafusion_expr::utils::from_plan; use datafusion_expr::{ - aggregate_function, function, is_false, is_not_false, is_not_true, is_not_unknown, - is_true, is_unknown, type_coercion, AggregateFunction, Expr, LogicalPlan, Operator, + aggregate_function, is_false, is_not_false, is_not_true, is_not_unknown, is_true, + is_unknown, type_coercion, AggregateFunction, Expr, LogicalPlan, Operator, Projection, WindowFrame, WindowFrameBound, WindowFrameUnits, }; use datafusion_expr::{ExprSchemable, Signature}; @@ -390,7 +390,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { let nex_expr = coerce_arguments_for_signature( args.as_slice(), &self.schema, - &function::signature(&fun), + &fun.signature(), )?; let expr = Expr::ScalarFunction(ScalarFunction::new(fun, nex_expr)); Ok(expr) diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 648dd4a144c6..5a5bdf4702e0 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -45,7 +45,7 @@ use arrow::{ }; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::{ - function, BuiltinScalarFunction, ColumnarValue, ScalarFunctionImplementation, + BuiltinScalarFunction, ColumnarValue, ScalarFunctionImplementation, }; use std::sync::Arc; @@ -62,7 +62,7 @@ pub fn create_physical_expr( .map(|e| e.data_type(input_schema)) .collect::>>()?; - let data_type = function::return_type(fun, &input_expr_types)?; + let data_type = fun.return_type(&input_expr_types)?; let fun_expr: ScalarFunctionImplementation = match fun { // These functions need args and input schema to pick an implementation @@ -2921,7 +2921,7 @@ mod tests { execution_props: &ExecutionProps, ) -> Result> { let type_coerced_phy_exprs = - coerce(input_phy_exprs, input_schema, &function::signature(fun)).unwrap(); + coerce(input_phy_exprs, input_schema, &fun.signature()).unwrap(); create_physical_expr(fun, &type_coerced_phy_exprs, input_schema, execution_props) } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 104a65832dcd..0fb6b7554776 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -18,7 +18,7 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{DFSchema, DataFusionError, Result}; use datafusion_expr::expr::{ScalarFunction, ScalarUDF}; -use datafusion_expr::function_err::suggest_valid_function; +use datafusion_expr::function::suggest_valid_function; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::window_frame::regularize; use datafusion_expr::{