From 37b1bda061bc442e47c7cb701962672f6c3b147c Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Fri, 15 Nov 2024 11:07:20 -0800 Subject: [PATCH 01/45] Make DFSchema::datatype_is_semantically_equal public (#13429) --- datafusion/common/src/dfschema.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index aa2d93989da19..e893cee089c93 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -686,7 +686,7 @@ impl DFSchema { /// name and type), ignoring both metadata and nullability. /// /// request to upstream: - fn datatype_is_semantically_equal(dt1: &DataType, dt2: &DataType) -> bool { + pub fn datatype_is_semantically_equal(dt1: &DataType, dt2: &DataType) -> bool { // check nested fields match (dt1, dt2) { (DataType::Dictionary(k1, v1), DataType::Dictionary(k2, v2)) => { From c51b432089bd7c6084650c59ecc8f017fe4debf7 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Fri, 15 Nov 2024 14:18:22 -0500 Subject: [PATCH 02/45] Add support for utf8view to nvl function (#13382) * Directly support utf8view in nvl. #13381 * Fix what looks like a merge error. --------- Co-authored-by: Andrew Lamb --- datafusion/functions/src/core/nvl.rs | 1 + datafusion/sqllogictest/test_files/nvl.slt | 30 +++++++++++++++++++ .../test_files/string/string_view.slt | 8 +++++ 3 files changed, 39 insertions(+) diff --git a/datafusion/functions/src/core/nvl.rs b/datafusion/functions/src/core/nvl.rs index 16438e1b6254f..24b6f5fc14fef 100644 --- a/datafusion/functions/src/core/nvl.rs +++ b/datafusion/functions/src/core/nvl.rs @@ -47,6 +47,7 @@ static SUPPORTED_NVL_TYPES: &[DataType] = &[ DataType::Int64, DataType::Float32, DataType::Float64, + DataType::Utf8View, DataType::Utf8, DataType::LargeUtf8, ]; diff --git a/datafusion/sqllogictest/test_files/nvl.slt b/datafusion/sqllogictest/test_files/nvl.slt index 81e79e1eb5b06..daab54307cc20 100644 --- a/datafusion/sqllogictest/test_files/nvl.slt +++ b/datafusion/sqllogictest/test_files/nvl.slt @@ -118,3 +118,33 @@ query I SELECT NVL(NULL, NULL); ---- NULL + +query T +SELECT NVL(arrow_cast(text_field, 'Utf8View'), 'zxb') FROM test; +---- +abc +def +ghij +zxb +zxc +zxb + +query T +SELECT NVL(arrow_cast('a', 'Utf8View'), 'zxb'); +---- +a + +query T +SELECT NVL('zxb', arrow_cast('a', 'Utf8View')); +---- +zxb + +query T +SELECT NVL(NULL, arrow_cast('a', 'Utf8View')); +---- +a + +query T +SELECT NVL(arrow_cast('a', 'Utf8View'), NULL); +---- +a diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt index 2b44c86f52d83..12295a01a9f1e 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -935,6 +935,14 @@ logical_plan 01)Projection: to_timestamp(test.column1_utf8view, Utf8("a,b,c,d")) AS c 02)--TableScan: test projection=[column1_utf8view] +## Ensure no casts for NVL +query TT +EXPLAIN SELECT NVL(column1_utf8view, 'a') as c2 FROM test; +---- +logical_plan +01)Projection: nvl(test.column1_utf8view, Utf8View("a")) AS c2 +02)--TableScan: test projection=[column1_utf8view] + ## Ensure no casts for nullif query TT EXPLAIN SELECT From d840e987cd855fbbdc5d3e5d69683a5b4f279bb6 Mon Sep 17 00:00:00 2001 From: Sherin Jacob Date: Sat, 16 Nov 2024 00:49:15 +0530 Subject: [PATCH 03/45] fix: serialize user-defined window functions to proto (#13421) * Adds roundtrip physical plan test * Adds enum for udwf to `WindowFunction` * initial fix for serializing udwf * Revives deleted test * Adds codec methods for physical plan * Rewrite error message * Minor: rename binding + formatting fixes * Extends `PhysicalExtensionCodec` for udwf * Minor: formatting * Restricts visibility to tests --- datafusion/physical-plan/src/windows/mod.rs | 8 +- datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 13 ++ datafusion/proto/src/generated/prost.rs | 4 +- .../proto/src/physical_plan/from_proto.rs | 6 + datafusion/proto/src/physical_plan/mod.rs | 10 +- .../proto/src/physical_plan/to_proto.rs | 25 ++- datafusion/proto/tests/cases/mod.rs | 60 ++++++- .../tests/cases/roundtrip_physical_plan.rs | 160 +++++++++++++++++- 9 files changed, 272 insertions(+), 15 deletions(-) diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index a323a958cc76f..32173c3ef17df 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -194,7 +194,7 @@ pub fn create_udwf_window_expr( /// Implements [`BuiltInWindowFunctionExpr`] for [`WindowUDF`] #[derive(Clone, Debug)] -struct WindowUDFExpr { +pub struct WindowUDFExpr { fun: Arc, args: Vec>, /// Display name @@ -209,6 +209,12 @@ struct WindowUDFExpr { ignore_nulls: bool, } +impl WindowUDFExpr { + pub fn fun(&self) -> &Arc { + &self.fun + } +} + impl BuiltInWindowFunctionExpr for WindowUDFExpr { fn as_any(&self) -> &dyn std::any::Any { self diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 6606b1e93f02f..504e5e1ceead7 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -853,6 +853,7 @@ message PhysicalWindowExprNode { oneof window_function { // BuiltInWindowFunction built_in_function = 2; string user_defined_aggr_function = 3; + string user_defined_window_function = 10; } repeated PhysicalExprNode args = 4; repeated PhysicalExprNode partition_by = 5; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 09c873b1f98ad..29920814a8025 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -16326,6 +16326,9 @@ impl serde::Serialize for PhysicalWindowExprNode { physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(v) => { struct_ser.serialize_field("userDefinedAggrFunction", v)?; } + physical_window_expr_node::WindowFunction::UserDefinedWindowFunction(v) => { + struct_ser.serialize_field("userDefinedWindowFunction", v)?; + } } } struct_ser.end() @@ -16350,6 +16353,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { "funDefinition", "user_defined_aggr_function", "userDefinedAggrFunction", + "user_defined_window_function", + "userDefinedWindowFunction", ]; #[allow(clippy::enum_variant_names)] @@ -16361,6 +16366,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { Name, FunDefinition, UserDefinedAggrFunction, + UserDefinedWindowFunction, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -16389,6 +16395,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { "name" => Ok(GeneratedField::Name), "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), "userDefinedAggrFunction" | "user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction), + "userDefinedWindowFunction" | "user_defined_window_function" => Ok(GeneratedField::UserDefinedWindowFunction), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16461,6 +16468,12 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { } window_function__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_window_expr_node::WindowFunction::UserDefinedAggrFunction); } + GeneratedField::UserDefinedWindowFunction => { + if window_function__.is_some() { + return Err(serde::de::Error::duplicate_field("userDefinedWindowFunction")); + } + window_function__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_window_expr_node::WindowFunction::UserDefinedWindowFunction); + } } } Ok(PhysicalWindowExprNode { diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index ad5320fc657c5..07090b7cba110 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1266,7 +1266,7 @@ pub struct PhysicalWindowExprNode { pub name: ::prost::alloc::string::String, #[prost(bytes = "vec", optional, tag = "9")] pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, - #[prost(oneof = "physical_window_expr_node::WindowFunction", tags = "3")] + #[prost(oneof = "physical_window_expr_node::WindowFunction", tags = "3, 10")] pub window_function: ::core::option::Option< physical_window_expr_node::WindowFunction, >, @@ -1278,6 +1278,8 @@ pub mod physical_window_expr_node { /// BuiltInWindowFunction built_in_function = 2; #[prost(string, tag = "3")] UserDefinedAggrFunction(::prost::alloc::string::String), + #[prost(string, tag = "10")] + UserDefinedWindowFunction(::prost::alloc::string::String), } } #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 1c5bdd0c02ba5..e528b38b84a8c 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -152,6 +152,12 @@ pub fn parse_physical_window_expr( None => registry.udaf(udaf_name)? }) } + protobuf::physical_window_expr_node::WindowFunction::UserDefinedWindowFunction(udwf_name) => { + WindowFunctionDefinition::WindowUDF(match &proto.fun_definition { + Some(buf) => codec.try_decode_udwf(udwf_name, buf)?, + None => registry.udwf(udwf_name)? + }) + } } } else { return Err(proto_error("Missing required field in protobuf")); diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 64e462d1695fd..292ce13d0eded 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -64,7 +64,7 @@ use datafusion::physical_plan::{ ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr, }; use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; -use datafusion_expr::{AggregateUDF, ScalarUDF}; +use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use crate::common::{byte_to_string, str_to_byte}; use crate::physical_plan::from_proto::{ @@ -2119,6 +2119,14 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync { fn try_encode_udaf(&self, _node: &AggregateUDF, _buf: &mut Vec) -> Result<()> { Ok(()) } + + fn try_decode_udwf(&self, name: &str, _buf: &[u8]) -> Result> { + not_impl_err!("PhysicalExtensionCodec is not provided for window function {name}") + } + + fn try_encode_udwf(&self, _node: &WindowUDF, _buf: &mut Vec) -> Result<()> { + Ok(()) + } } #[derive(Debug)] diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 60dcd650191d6..7d9a524af8288 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -19,14 +19,14 @@ use std::sync::Arc; #[cfg(feature = "parquet")] use datafusion::datasource::file_format::parquet::ParquetSink; -use datafusion::physical_expr::window::SlidingAggregateWindowExpr; +use datafusion::physical_expr::window::{BuiltInWindowExpr, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, Literal, NegativeExpr, NotExpr, TryCastExpr, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; -use datafusion::physical_plan::windows::PlainAggregateWindowExpr; +use datafusion::physical_plan::windows::{PlainAggregateWindowExpr, WindowUDFExpr}; use datafusion::physical_plan::{Partitioning, PhysicalExpr, WindowExpr}; use datafusion::{ datasource::{ @@ -68,7 +68,7 @@ pub fn serialize_physical_aggr_expr( ordering_req, distinct: aggr_expr.is_distinct(), ignore_nulls: aggr_expr.ignore_nulls(), - fun_definition: (!buf.is_empty()).then_some(buf) + fun_definition: (!buf.is_empty()).then_some(buf), }, )), }) @@ -120,6 +120,25 @@ pub fn serialize_physical_window_expr( window_frame, codec, )? + } else if let Some(built_in_window_expr) = expr.downcast_ref::() { + if let Some(expr) = built_in_window_expr + .get_built_in_func_expr() + .as_any() + .downcast_ref::() + { + let mut buf = Vec::new(); + codec.try_encode_udwf(expr.fun(), &mut buf)?; + ( + physical_window_expr_node::WindowFunction::UserDefinedWindowFunction( + expr.fun().name().to_string(), + ), + (!buf.is_empty()).then_some(buf), + ) + } else { + return not_impl_err!( + "User-defined window function not supported: {window_expr:?}" + ); + } } else { return not_impl_err!("WindowExpr not supported: {window_expr:?}"); }; diff --git a/datafusion/proto/tests/cases/mod.rs b/datafusion/proto/tests/cases/mod.rs index fbb2cd8f1e832..4d69ca075483b 100644 --- a/datafusion/proto/tests/cases/mod.rs +++ b/datafusion/proto/tests/cases/mod.rs @@ -15,15 +15,18 @@ // specific language governing permissions and limitations // under the License. +use arrow::datatypes::{DataType, Field}; use std::any::Any; - -use arrow::datatypes::DataType; +use std::fmt::Debug; use datafusion_common::plan_err; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, Signature, Volatility, + Accumulator, AggregateUDFImpl, ColumnarValue, PartitionEvaluator, ScalarUDFImpl, + Signature, Volatility, WindowUDFImpl, }; +use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; mod roundtrip_logical_plan; mod roundtrip_physical_plan; @@ -125,3 +128,54 @@ pub struct MyAggregateUdfNode { #[prost(string, tag = "1")] pub result: String, } + +#[derive(Debug)] +pub(in crate::cases) struct CustomUDWF { + signature: Signature, + payload: String, +} + +impl CustomUDWF { + pub fn new(payload: String) -> Self { + Self { + signature: Signature::exact(vec![DataType::Int64], Volatility::Immutable), + payload, + } + } +} + +impl WindowUDFImpl for CustomUDWF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "custom_udwf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> datafusion_common::Result> { + Ok(Box::new(CustomUDWFEvaluator {})) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { + Ok(Field::new(field_args.name(), DataType::UInt64, false)) + } +} + +#[derive(Debug)] +struct CustomUDWFEvaluator; + +impl PartitionEvaluator for CustomUDWFEvaluator {} + +#[derive(Clone, PartialEq, ::prost::Message)] +pub(in crate::cases) struct CustomUDWFNode { + #[prost(string, tag = "1")] + pub payload: String, +} diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index aab63dd8bd66a..efa462aa7a855 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -32,7 +32,10 @@ use datafusion_functions_aggregate::array_agg::array_agg_udaf; use datafusion_functions_aggregate::min_max::max_udaf; use prost::Message; -use crate::cases::{MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf, MyRegexUdfNode}; +use crate::cases::{ + CustomUDWF, CustomUDWFNode, MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf, + MyRegexUdfNode, +}; use datafusion::arrow::array::ArrayRef; use datafusion::arrow::compute::kernels::sort::SortOptions; use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; @@ -47,9 +50,11 @@ use datafusion::datasource::physical_plan::{ }; use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::sum::sum_udaf; +use datafusion::functions_window::nth_value::nth_value_udwf; +use datafusion::functions_window::row_number::row_number_udwf; use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility}; use datafusion::physical_expr::expressions::Literal; -use datafusion::physical_expr::window::SlidingAggregateWindowExpr; +use datafusion::physical_expr::window::{BuiltInWindowExpr, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{ LexOrdering, LexRequirement, PhysicalSortRequirement, ScalarFunctionExpr, }; @@ -73,8 +78,13 @@ use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::unnest::{ListUnnest, UnnestExec}; -use datafusion::physical_plan::windows::{PlainAggregateWindowExpr, WindowAggExec}; -use datafusion::physical_plan::{ExecutionPlan, Partitioning, PhysicalExpr, Statistics}; +use datafusion::physical_plan::windows::{ + create_udwf_window_expr, BoundedWindowAggExec, PlainAggregateWindowExpr, + WindowAggExec, +}; +use datafusion::physical_plan::{ + ExecutionPlan, InputOrderMode, Partitioning, PhysicalExpr, Statistics, +}; use datafusion::prelude::SessionContext; use datafusion::scalar::ScalarValue; use datafusion_common::config::TableParquetOptions; @@ -87,7 +97,7 @@ use datafusion_common::{ }; use datafusion_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, - Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, + Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, WindowUDF, }; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::nth_value::nth_value_udaf; @@ -263,12 +273,74 @@ fn roundtrip_nested_loop_join() -> Result<()> { Ok(()) } +#[test] +fn roundtrip_udwf() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + let udwf_expr = Arc::new(BuiltInWindowExpr::new( + create_udwf_window_expr( + &row_number_udwf(), + &[], + &schema, + "row_number() PARTITION BY [a] ORDER BY [b] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW".to_string(), + false, + )?, + &[ + col("a", &schema)? + ], + &LexOrdering::new(vec![ + PhysicalSortExpr::new(col("b", &schema)?, SortOptions::new(true, true)), + ]), + Arc::new(WindowFrame::new(None)), + )); + + let input = Arc::new(EmptyExec::new(schema.clone())); + + roundtrip_test(Arc::new(BoundedWindowAggExec::try_new( + vec![udwf_expr], + input, + vec![col("a", &schema)?], + InputOrderMode::Sorted, + )?)) +} + #[test] fn roundtrip_window() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let window_frame = WindowFrame::new_bounds( + datafusion_expr::WindowFrameUnits::Range, + WindowFrameBound::Preceding(ScalarValue::Int64(None)), + WindowFrameBound::CurrentRow, + ); + + let nth_value_window = + create_udwf_window_expr( + &nth_value_udwf(), + &[col("a", &schema)?, + lit(2)], schema.as_ref(), + "NTH_VALUE(a, 2) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW".to_string(), + false, + )?; + let udwf_expr = Arc::new(BuiltInWindowExpr::new( + nth_value_window, + &[col("b", &schema)?], + &LexOrdering { + inner: vec![PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }], + }, + Arc::new(window_frame), + )); + let plain_aggr_window_expr = Arc::new(PlainAggregateWindowExpr::new( AggregateExprBuilder::new( avg_udaf(), @@ -306,7 +378,7 @@ fn roundtrip_window() -> Result<()> { let input = Arc::new(EmptyExec::new(schema.clone())); roundtrip_test(Arc::new(WindowAggExec::try_new( - vec![plain_aggr_window_expr, sliding_aggr_window_expr], + vec![plain_aggr_window_expr, sliding_aggr_window_expr, udwf_expr], input, vec![col("b", &schema)?], )?)) @@ -948,6 +1020,33 @@ impl PhysicalExtensionCodec for UDFExtensionCodec { } Ok(()) } + + fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result> { + if name == "custom_udwf" { + let proto = CustomUDWFNode::decode(buf).map_err(|err| { + DataFusionError::Internal(format!("failed to decode custom_udwf: {err}")) + })?; + + Ok(Arc::new(WindowUDF::from(CustomUDWF::new(proto.payload)))) + } else { + not_impl_err!( + "unrecognized user-defined window function implementation, cannot decode" + ) + } + } + + fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec) -> Result<()> { + let binding = node.inner(); + if let Some(udwf) = binding.as_any().downcast_ref::() { + let proto = CustomUDWFNode { + payload: udwf.payload.clone(), + }; + proto.encode(buf).map_err(|err| { + DataFusionError::Internal(format!("failed to encode udwf: {err:?}")) + })?; + } + Ok(()) + } } #[test] @@ -1005,6 +1104,55 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { Ok(()) } +#[test] +fn roundtrip_udwf_extension_codec() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + let custom_udwf = Arc::new(WindowUDF::from(CustomUDWF::new("payload".to_string()))); + let udwf = create_udwf_window_expr( + &custom_udwf, + &[col("a", &schema)?], + schema.as_ref(), + "custom_udwf(a) PARTITION BY [b] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW".to_string(), + false, + )?; + + let window_frame = WindowFrame::new_bounds( + datafusion_expr::WindowFrameUnits::Range, + WindowFrameBound::Preceding(ScalarValue::Int64(None)), + WindowFrameBound::CurrentRow, + ); + + let udwf_expr = Arc::new(BuiltInWindowExpr::new( + udwf, + &[col("b", &schema)?], + &LexOrdering { + inner: vec![PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }], + }, + Arc::new(window_frame), + )); + + let input = Arc::new(EmptyExec::new(schema.clone())); + let window = Arc::new(BoundedWindowAggExec::try_new( + vec![udwf_expr], + input, + vec![col("b", &schema)?], + InputOrderMode::Sorted, + )?); + + let ctx = SessionContext::new(); + roundtrip_test_and_return(window, &ctx, &UDFExtensionCodec)?; + Ok(()) +} + #[test] fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { let field_text = Field::new("text", DataType::Utf8, true); From 7e69580032b138aadf1cf6975cd2004916885bd5 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Fri, 15 Nov 2024 14:41:45 -0500 Subject: [PATCH 04/45] Add support for Utf8View to crypto functions #13406 (#13407) --- datafusion/functions/src/crypto/basic.rs | 57 ++++++++++++++---- datafusion/functions/src/crypto/digest.rs | 1 + datafusion/functions/src/crypto/md5.rs | 4 +- datafusion/functions/src/crypto/sha224.rs | 2 +- datafusion/functions/src/crypto/sha256.rs | 2 +- datafusion/functions/src/crypto/sha384.rs | 2 +- datafusion/functions/src/crypto/sha512.rs | 2 +- datafusion/sqllogictest/test_files/expr.slt | 5 ++ .../test_files/string/string_view.slt | 60 +++++++++++++++++++ 9 files changed, 116 insertions(+), 19 deletions(-) diff --git a/datafusion/functions/src/crypto/basic.rs b/datafusion/functions/src/crypto/basic.rs index 716afd84a9c9c..74dc5d517c2ba 100644 --- a/datafusion/functions/src/crypto/basic.rs +++ b/datafusion/functions/src/crypto/basic.rs @@ -17,17 +17,18 @@ //! "crypto" DataFusion functions -use arrow::array::StringArray; use arrow::array::{Array, ArrayRef, BinaryArray, OffsetSizeTrait}; +use arrow::array::{AsArray, GenericStringArray, StringArray, StringViewArray}; use arrow::datatypes::DataType; use blake2::{Blake2b512, Blake2s256, Digest}; use blake3::Hasher as Blake3; use datafusion_common::cast::as_binary_array; +use arrow::compute::StringArrayType; use datafusion_common::plan_err; use datafusion_common::{ - cast::{as_generic_binary_array, as_generic_string_array}, - exec_err, internal_err, DataFusionError, Result, ScalarValue, + cast::as_generic_binary_array, exec_err, internal_err, DataFusionError, Result, + ScalarValue, }; use datafusion_expr::ColumnarValue; use md5::Md5; @@ -121,9 +122,9 @@ pub fn digest(args: &[ColumnarValue]) -> Result { } let digest_algorithm = match &args[1] { ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => { - method.parse::() - } + ScalarValue::Utf8View(Some(method)) + | ScalarValue::Utf8(Some(method)) + | ScalarValue::LargeUtf8(Some(method)) => method.parse::(), other => exec_err!("Unsupported data type {other:?} for function digest"), }, ColumnarValue::Array(_) => { @@ -132,6 +133,7 @@ pub fn digest(args: &[ColumnarValue]) -> Result { }?; digest_process(&args[0], digest_algorithm) } + impl FromStr for DigestAlgorithm { type Err = DataFusionError; fn from_str(name: &str) -> Result { @@ -166,12 +168,14 @@ impl FromStr for DigestAlgorithm { }) } } + impl fmt::Display for DigestAlgorithm { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", format!("{self:?}").to_lowercase()) } } -// /// computes md5 hash digest of the given input + +/// computes md5 hash digest of the given input pub fn md5(args: &[ColumnarValue]) -> Result { if args.len() != 1 { return exec_err!( @@ -180,7 +184,9 @@ pub fn md5(args: &[ColumnarValue]) -> Result { DigestAlgorithm::Md5 ); } + let value = digest_process(&args[0], DigestAlgorithm::Md5)?; + // md5 requires special handling because of its unique utf8 return type Ok(match value { ColumnarValue::Array(array) => { @@ -214,7 +220,8 @@ pub fn utf8_or_binary_to_binary_type( name: &str, ) -> Result { Ok(match arg_type { - DataType::LargeUtf8 + DataType::Utf8View + | DataType::LargeUtf8 | DataType::Utf8 | DataType::Binary | DataType::LargeBinary => DataType::Binary, @@ -296,8 +303,30 @@ impl DigestAlgorithm { where T: OffsetSizeTrait, { - let input_value = as_generic_string_array::(value)?; - let array: ArrayRef = match self { + let array = match value.data_type() { + DataType::Utf8 | DataType::LargeUtf8 => { + let v = value.as_string::(); + self.digest_utf8_array_impl::<&GenericStringArray>(v) + } + DataType::Utf8View => { + let v = value.as_string_view(); + self.digest_utf8_array_impl::<&StringViewArray>(v) + } + other => { + return exec_err!("unsupported type for digest_utf_array: {other:?}") + } + }; + Ok(ColumnarValue::Array(array)) + } + + pub fn digest_utf8_array_impl<'a, StringArrType>( + self, + input_value: StringArrType, + ) -> ArrayRef + where + StringArrType: StringArrayType<'a>, + { + match self { Self::Md5 => digest_to_array!(Md5, input_value), Self::Sha224 => digest_to_array!(Sha224, input_value), Self::Sha256 => digest_to_array!(Sha256, input_value), @@ -318,8 +347,7 @@ impl DigestAlgorithm { .collect(); Arc::new(binary_array) } - }; - Ok(ColumnarValue::Array(array)) + } } } pub fn digest_process( @@ -328,6 +356,7 @@ pub fn digest_process( ) -> Result { match value { ColumnarValue::Array(a) => match a.data_type() { + DataType::Utf8View => digest_algorithm.digest_utf8_array::(a.as_ref()), DataType::Utf8 => digest_algorithm.digest_utf8_array::(a.as_ref()), DataType::LargeUtf8 => digest_algorithm.digest_utf8_array::(a.as_ref()), DataType::Binary => digest_algorithm.digest_binary_array::(a.as_ref()), @@ -339,7 +368,9 @@ pub fn digest_process( ), }, ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => { + ScalarValue::Utf8View(a) + | ScalarValue::Utf8(a) + | ScalarValue::LargeUtf8(a) => { Ok(digest_algorithm .digest_scalar(a.as_ref().map(|s: &String| s.as_bytes()))) } diff --git a/datafusion/functions/src/crypto/digest.rs b/datafusion/functions/src/crypto/digest.rs index 0e43fb7785dfd..f738c6e3e40f2 100644 --- a/datafusion/functions/src/crypto/digest.rs +++ b/datafusion/functions/src/crypto/digest.rs @@ -42,6 +42,7 @@ impl DigestFunc { Self { signature: Signature::one_of( vec![ + Exact(vec![Utf8View, Utf8View]), Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, Utf8]), Exact(vec![Binary, Utf8]), diff --git a/datafusion/functions/src/crypto/md5.rs b/datafusion/functions/src/crypto/md5.rs index 062d63bcc0182..0f18fd47b4cf0 100644 --- a/datafusion/functions/src/crypto/md5.rs +++ b/datafusion/functions/src/crypto/md5.rs @@ -42,7 +42,7 @@ impl Md5Func { Self { signature: Signature::uniform( 1, - vec![Utf8, LargeUtf8, Binary, LargeBinary], + vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary], Volatility::Immutable, ), } @@ -65,7 +65,7 @@ impl ScalarUDFImpl for Md5Func { use DataType::*; Ok(match &arg_types[0] { LargeUtf8 | LargeBinary => LargeUtf8, - Utf8 | Binary => Utf8, + Utf8View | Utf8 | Binary => Utf8, Null => Null, Dictionary(_, t) => match **t { LargeUtf8 | LargeBinary => LargeUtf8, diff --git a/datafusion/functions/src/crypto/sha224.rs b/datafusion/functions/src/crypto/sha224.rs index 39202d5bf6914..f0bfcb9fab3ba 100644 --- a/datafusion/functions/src/crypto/sha224.rs +++ b/datafusion/functions/src/crypto/sha224.rs @@ -43,7 +43,7 @@ impl SHA224Func { Self { signature: Signature::uniform( 1, - vec![Utf8, LargeUtf8, Binary, LargeBinary], + vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary], Volatility::Immutable, ), } diff --git a/datafusion/functions/src/crypto/sha256.rs b/datafusion/functions/src/crypto/sha256.rs index 74deb3fc6caad..0a0044f72206f 100644 --- a/datafusion/functions/src/crypto/sha256.rs +++ b/datafusion/functions/src/crypto/sha256.rs @@ -42,7 +42,7 @@ impl SHA256Func { Self { signature: Signature::uniform( 1, - vec![Utf8, LargeUtf8, Binary, LargeBinary], + vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary], Volatility::Immutable, ), } diff --git a/datafusion/functions/src/crypto/sha384.rs b/datafusion/functions/src/crypto/sha384.rs index 9b1e1ba9ec3cb..7f8220e5f9d5f 100644 --- a/datafusion/functions/src/crypto/sha384.rs +++ b/datafusion/functions/src/crypto/sha384.rs @@ -42,7 +42,7 @@ impl SHA384Func { Self { signature: Signature::uniform( 1, - vec![Utf8, LargeUtf8, Binary, LargeBinary], + vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary], Volatility::Immutable, ), } diff --git a/datafusion/functions/src/crypto/sha512.rs b/datafusion/functions/src/crypto/sha512.rs index c88579fd08eea..d2d51bfa53abf 100644 --- a/datafusion/functions/src/crypto/sha512.rs +++ b/datafusion/functions/src/crypto/sha512.rs @@ -42,7 +42,7 @@ impl SHA512Func { Self { signature: Signature::uniform( 1, - vec![Utf8, LargeUtf8, Binary, LargeBinary], + vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary], Volatility::Immutable, ), } diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index c653113fd438e..15bf771c65271 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -2225,6 +2225,11 @@ SELECT digest('','blake3'); ---- af1349b9f5f9a1a6a0404dea36dcc9499bcb25c9adc112b7cc9a93cae41f3262 +# vverify utf8view +query ? +SELECT sha224(arrow_cast('tom', 'Utf8View')); +---- +0bf6cb62649c42a9ae3876ab6f6d92ad36cb5414e495f8873292be4d query T SELECT substring('alphabet', 1) diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt index 12295a01a9f1e..5a08f3f5447a5 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -963,6 +963,66 @@ logical_plan 01)Projection: nullif(test.column1_utf8view, test.column1_utf8view) AS c 02)--TableScan: test projection=[column1_utf8view] +## Ensure no casts for md5 +query TT +EXPLAIN SELECT + md5(column1_utf8view) as c +FROM test; +---- +logical_plan +01)Projection: md5(test.column1_utf8view) AS c +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for sha224 +query TT +EXPLAIN SELECT + sha224(column1_utf8view) as c +FROM test; +---- +logical_plan +01)Projection: sha224(test.column1_utf8view) AS c +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for sha256 +query TT +EXPLAIN SELECT + sha256(column1_utf8view) as c +FROM test; +---- +logical_plan +01)Projection: sha256(test.column1_utf8view) AS c +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for sha384 +query TT +EXPLAIN SELECT + sha384(column1_utf8view) as c +FROM test; +---- +logical_plan +01)Projection: sha384(test.column1_utf8view) AS c +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for sha512 +query TT +EXPLAIN SELECT + sha512(column1_utf8view) as c +FROM test; +---- +logical_plan +01)Projection: sha512(test.column1_utf8view) AS c +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for digest +query TT +EXPLAIN SELECT + digest(column1_utf8view, 'md5') as c +FROM test; +---- +logical_plan +01)Projection: digest(test.column1_utf8view, Utf8View("md5")) AS c +02)--TableScan: test projection=[column1_utf8view] + ## Ensure no casts for binary operators # `~` operator (regex match) query TT From 1e96a0a76ca60364aa74d9a8bd8a4c15efdfb9de Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 15 Nov 2024 14:53:20 -0500 Subject: [PATCH 05/45] Fix `concat` simplifier for Utf8View types (#13346) * Add string view options to concat, fix simplifier for handling concat to return the same schema as without * Set coersion ordering * Add to simplification unit test to catch changes in type for concat * Update coersion ordering * Simplify computing merged type for concat --- .../core/tests/expr_api/simplification.rs | 24 +++-- datafusion/functions/src/string/concat.rs | 87 ++++++++++++++++--- 2 files changed, 95 insertions(+), 16 deletions(-) diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index 68785b7a5a45c..1e6ff8088d0af 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -483,10 +483,12 @@ fn expr_test_schema() -> DFSchemaRef { Field::new("c2", DataType::Boolean, true), Field::new("c3", DataType::Int64, true), Field::new("c4", DataType::UInt32, true), + Field::new("c5", DataType::Utf8View, true), Field::new("c1_non_null", DataType::Utf8, false), Field::new("c2_non_null", DataType::Boolean, false), Field::new("c3_non_null", DataType::Int64, false), Field::new("c4_non_null", DataType::UInt32, false), + Field::new("c5_non_null", DataType::Utf8View, false), ]) .to_dfschema_ref() .unwrap() @@ -665,20 +667,32 @@ fn test_simplify_concat_ws_with_null() { } #[test] -fn test_simplify_concat() { +fn test_simplify_concat() -> Result<()> { + let schema = expr_test_schema(); let null = lit(ScalarValue::Utf8(None)); let expr = concat(vec![ null.clone(), - col("c0"), + col("c1"), lit("hello "), null.clone(), lit("rust"), - col("c1"), + lit(ScalarValue::Utf8View(Some("!".to_string()))), + col("c2"), lit(""), null, + col("c5"), ]); - let expected = concat(vec![col("c0"), lit("hello rust"), col("c1")]); - test_simplify(expr, expected) + let expr_datatype = expr.get_type(schema.as_ref())?; + let expected = concat(vec![ + col("c1"), + lit(ScalarValue::Utf8View(Some("hello rust!".to_string()))), + col("c2"), + col("c5"), + ]); + let expected_datatype = expected.get_type(schema.as_ref())?; + assert_eq!(expr_datatype, expected_datatype); + test_simplify(expr, expected); + Ok(()) } #[test] fn test_simplify_cycles() { diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index f1e60004ddd00..d49a2777b4ff8 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -48,7 +48,7 @@ impl ConcatFunc { use DataType::*; Self { signature: Signature::variadic( - vec![Utf8, Utf8View, LargeUtf8], + vec![Utf8View, Utf8, LargeUtf8], Volatility::Immutable, ), } @@ -110,8 +110,19 @@ impl ScalarUDFImpl for ConcatFunc { if array_len.is_none() { let mut result = String::new(); for arg in args { - if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = arg { - result.push_str(v); + match arg { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) + | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v))) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(v))) => { + result.push_str(v); + } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) + | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {} + other => plan_err!( + "Concat function does not support scalar type {:?}", + other + )?, } } @@ -282,15 +293,37 @@ pub fn simplify_concat(args: Vec) -> Result { let mut new_args = Vec::with_capacity(args.len()); let mut contiguous_scalar = "".to_string(); + let return_type = { + let data_types: Vec<_> = args + .iter() + .filter_map(|expr| match expr { + Expr::Literal(l) => Some(l.data_type()), + _ => None, + }) + .collect(); + ConcatFunc::new().return_type(&data_types) + }?; + for arg in args.clone() { match arg { + Expr::Literal(ScalarValue::Utf8(None)) => {} + Expr::Literal(ScalarValue::LargeUtf8(None)) => { + } + Expr::Literal(ScalarValue::Utf8View(None)) => { } + // filter out `null` args - Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None)) => {} // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. // Concatenate it with the `contiguous_scalar`. - Expr::Literal( - ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v)), - ) => contiguous_scalar += &v, + Expr::Literal(ScalarValue::Utf8(Some(v))) => { + contiguous_scalar += &v; + } + Expr::Literal(ScalarValue::LargeUtf8(Some(v))) => { + contiguous_scalar += &v; + } + Expr::Literal(ScalarValue::Utf8View(Some(v))) => { + contiguous_scalar += &v; + } + Expr::Literal(x) => { return internal_err!( "The scalar {x} should be casted to string type during the type coercion." @@ -301,7 +334,12 @@ pub fn simplify_concat(args: Vec) -> Result { // Then pushing this arg to the `new_args`. arg => { if !contiguous_scalar.is_empty() { - new_args.push(lit(contiguous_scalar)); + match return_type { + DataType::Utf8 => new_args.push(lit(contiguous_scalar)), + DataType::LargeUtf8 => new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))), + DataType::Utf8View => new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))), + _ => unreachable!(), + } contiguous_scalar = "".to_string(); } new_args.push(arg); @@ -310,7 +348,16 @@ pub fn simplify_concat(args: Vec) -> Result { } if !contiguous_scalar.is_empty() { - new_args.push(lit(contiguous_scalar)); + match return_type { + DataType::Utf8 => new_args.push(lit(contiguous_scalar)), + DataType::LargeUtf8 => { + new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))) + } + DataType::Utf8View => { + new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))) + } + _ => unreachable!(), + } } if !args.eq(&new_args) { @@ -392,6 +439,17 @@ mod tests { LargeUtf8, LargeStringArray ); + test_function!( + ConcatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))), + ], + Ok(Some("aacc")), + &str, + Utf8View, + StringViewArray + ); Ok(()) } @@ -406,11 +464,18 @@ mod tests { None, Some("z"), ]))); - let args = &[c0, c1, c2]; + let c3 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string()))); + let c4 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![ + Some("a"), + None, + Some("b"), + ]))); + let args = &[c0, c1, c2, c3, c4]; let result = ConcatFunc::new().invoke_batch(args, 3)?; let expected = - Arc::new(StringArray::from(vec!["foo,x", "bar,", "baz,z"])) as ArrayRef; + Arc::new(StringViewArray::from(vec!["foo,x,a", "bar,,", "baz,z,b"])) + as ArrayRef; match &result { ColumnarValue::Array(array) => { assert_eq!(&expected, array); From 5ea1d31ca6da7136ee5e9786f817c6d5baff5f13 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Sat, 16 Nov 2024 04:06:39 +0800 Subject: [PATCH 06/45] Add sort integration benchmark (#13306) * Add sort integration benchmark * clippy * review --- benchmarks/README.md | 24 +++ benchmarks/bench.sh | 18 ++ benchmarks/src/bin/dfbench.rs | 4 +- benchmarks/src/lib.rs | 1 + benchmarks/src/sort_tpch.rs | 320 ++++++++++++++++++++++++++++++++++ 5 files changed, 366 insertions(+), 1 deletion(-) create mode 100644 benchmarks/src/sort_tpch.rs diff --git a/benchmarks/README.md b/benchmarks/README.md index a9aa1afb97a1c..cccd7f44f5047 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -330,6 +330,30 @@ steps. The tests sort the entire dataset using several different sort orders. +## Sort TPCH + +Test performance of end-to-end sort SQL queries. (While the `Sort` benchmark focuses on a single sort executor, this benchmark tests how sorting is executed across multiple CPU cores by benchmarking sorting the whole relational table.) + +Sort integration benchmark runs whole table sort queries on TPCH `lineitem` table, with different characteristics. For example, different number of sort keys, different sort key cardinality, different number of payload columns, etc. + +See [`sort_tpch.rs`](src/sort_tpch.rs) for more details. + +### Sort TPCH Benchmark Example Runs +1. Run all queries with default setting: +```bash + cargo run --release --bin dfbench -- sort-tpch -p '....../datafusion/benchmarks/data/tpch_sf1' -o '/tmp/sort_tpch.json' +``` + +2. Run a specific query: +```bash + cargo run --release --bin dfbench -- sort-tpch -p '....../datafusion/benchmarks/data/tpch_sf1' -o '/tmp/sort_tpch.json' --query 2 +``` + +3. Run all queries with `bench.sh` script: +```bash +./bench.sh run sort_tpch +``` + ## IMDB Run Join Order Benchmark (JOB) on IMDB dataset. diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index 47c5d1261605b..b02bfee2454e2 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -75,6 +75,7 @@ tpch10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), tpch_mem10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), query from memory parquet: Benchmark of parquet reader's filtering speed sort: Benchmark of sorting speed +sort_tpch: Benchmark of sorting speed for end-to-end sort queries on TPCH dataset clickbench_1: ClickBench queries against a single parquet file clickbench_partitioned: ClickBench queries against a partitioned (100 files) parquet clickbench_extended: ClickBench \"inspired\" queries against a single parquet (DataFusion specific) @@ -175,6 +176,10 @@ main() { # same data as for tpch data_tpch "1" ;; + sort_tpch) + # same data as for tpch + data_tpch "1" + ;; *) echo "Error: unknown benchmark '$BENCHMARK' for data generation" usage @@ -252,6 +257,9 @@ main() { external_aggr) run_external_aggr ;; + sort_tpch) + run_sort_tpch + ;; *) echo "Error: unknown benchmark '$BENCHMARK' for run" usage @@ -549,6 +557,16 @@ run_external_aggr() { $CARGO_COMMAND --bin external_aggr -- benchmark --partitions 4 --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" } +# Runs the sort integration benchmark +run_sort_tpch() { + TPCH_DIR="${DATA_DIR}/tpch_sf1" + RESULTS_FILE="${RESULTS_DIR}/sort_tpch.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running sort tpch benchmark..." + + $CARGO_COMMAND --bin dfbench -- sort-tpch --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" +} + compare_benchmarks() { BASE_RESULTS_DIR="${SCRIPT_DIR}/results" diff --git a/benchmarks/src/bin/dfbench.rs b/benchmarks/src/bin/dfbench.rs index f7b84116e793a..81aa5437dd5f5 100644 --- a/benchmarks/src/bin/dfbench.rs +++ b/benchmarks/src/bin/dfbench.rs @@ -33,7 +33,7 @@ static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; #[global_allocator] static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; -use datafusion_benchmarks::{clickbench, imdb, parquet_filter, sort, tpch}; +use datafusion_benchmarks::{clickbench, imdb, parquet_filter, sort, sort_tpch, tpch}; #[derive(Debug, StructOpt)] #[structopt(about = "benchmark command")] @@ -43,6 +43,7 @@ enum Options { Clickbench(clickbench::RunOpt), ParquetFilter(parquet_filter::RunOpt), Sort(sort::RunOpt), + SortTpch(sort_tpch::RunOpt), Imdb(imdb::RunOpt), } @@ -57,6 +58,7 @@ pub async fn main() -> Result<()> { Options::Clickbench(opt) => opt.run().await, Options::ParquetFilter(opt) => opt.run().await, Options::Sort(opt) => opt.run().await, + Options::SortTpch(opt) => opt.run().await, Options::Imdb(opt) => opt.run().await, } } diff --git a/benchmarks/src/lib.rs b/benchmarks/src/lib.rs index 02410e0cfa01e..2d37d78764d78 100644 --- a/benchmarks/src/lib.rs +++ b/benchmarks/src/lib.rs @@ -20,5 +20,6 @@ pub mod clickbench; pub mod imdb; pub mod parquet_filter; pub mod sort; +pub mod sort_tpch; pub mod tpch; pub mod util; diff --git a/benchmarks/src/sort_tpch.rs b/benchmarks/src/sort_tpch.rs new file mode 100644 index 0000000000000..4b83b3b8889ac --- /dev/null +++ b/benchmarks/src/sort_tpch.rs @@ -0,0 +1,320 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module provides integration benchmark for sort operation. +//! It will run different sort SQL queries on TPCH `lineitem` parquet dataset. +//! +//! Another `Sort` benchmark focus on single core execution. This benchmark +//! runs end-to-end sort queries and test the performance on multiple CPU cores. + +use futures::StreamExt; +use std::path::PathBuf; +use std::sync::Arc; +use structopt::StructOpt; + +use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::listing::{ + ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, +}; +use datafusion::datasource::{MemTable, TableProvider}; +use datafusion::error::Result; +use datafusion::execution::runtime_env::RuntimeConfig; +use datafusion::physical_plan::display::DisplayableExecutionPlan; +use datafusion::physical_plan::{displayable, execute_stream}; +use datafusion::prelude::*; +use datafusion_common::instant::Instant; +use datafusion_common::DEFAULT_PARQUET_EXTENSION; + +use crate::util::{BenchmarkRun, CommonOpt}; + +#[derive(Debug, StructOpt)] +pub struct RunOpt { + /// Common options + #[structopt(flatten)] + common: CommonOpt, + + /// Sort query number. If not specified, runs all queries + #[structopt(short, long)] + query: Option, + + /// Path to data files (lineitem). Only parquet format is supported + #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] + path: PathBuf, + + /// Path to JSON benchmark result to be compare using `compare.py` + #[structopt(parse(from_os_str), short = "o", long = "output")] + output_path: Option, + + /// Load the data into a MemTable before executing the query + #[structopt(short = "m", long = "mem-table")] + mem_table: bool, +} + +struct QueryResult { + elapsed: std::time::Duration, + row_count: usize, +} + +impl RunOpt { + const SORT_TABLES: [&'static str; 1] = ["lineitem"]; + + /// Sort queries with different characteristics: + /// - Sort key with fixed length or variable length (VARCHAR) + /// - Sort key with different cardinality + /// - Different number of sort keys + /// - Different number of payload columns (thin: 1 additional column other + /// than sort keys; wide: all columns except sort keys) + /// + /// DataSet is `lineitem` table in TPCH dataset (16 columns, 6M rows for + /// scale factor 1.0, cardinality is counted from SF1 dataset) + /// + /// Key Columns: + /// - Column `l_linenumber`, type: `INTEGER`, cardinality: 7 + /// - Column `l_suppkey`, type: `BIGINT`, cardinality: 10k + /// - Column `l_orderkey`, type: `BIGINT`, cardinality: 1.5M + /// - Column `l_comment`, type: `VARCHAR`, cardinality: 4.5M (len is ~26 chars) + /// + /// Payload Columns: + /// - Thin variant: `l_partkey` column with `BIGINT` type (1 column) + /// - Wide variant: all columns except for possible key columns (12 columns) + const SORT_QUERIES: [&'static str; 10] = [ + // Q1: 1 sort key (type: INTEGER, cardinality: 7) + 1 payload column + r#" + SELECT l_linenumber, l_partkey + FROM lineitem + ORDER BY l_linenumber + "#, + // Q2: 1 sort key (type: BIGINT, cardinality: 1.5M) + 1 payload column + r#" + SELECT l_orderkey, l_partkey + FROM lineitem + ORDER BY l_orderkey + "#, + // Q3: 1 sort key (type: VARCHAR, cardinality: 4.5M) + 1 payload column + r#" + SELECT l_comment, l_partkey + FROM lineitem + ORDER BY l_comment + "#, + // Q4: 2 sort keys {(BIGINT, 1.5M), (INTEGER, 7)} + 1 payload column + r#" + SELECT l_orderkey, l_linenumber, l_partkey + FROM lineitem + ORDER BY l_orderkey, l_linenumber + "#, + // Q5: 3 sort keys {(INTEGER, 7), (BIGINT, 10k), (BIGINT, 1.5M)} + no payload column + r#" + SELECT l_linenumber, l_suppkey, l_orderkey + FROM lineitem + ORDER BY l_linenumber, l_suppkey, l_orderkey + "#, + // Q6: 3 sort keys {(INTEGER, 7), (BIGINT, 10k), (BIGINT, 1.5M)} + 1 payload column + r#" + SELECT l_linenumber, l_suppkey, l_orderkey, l_partkey + FROM lineitem + ORDER BY l_linenumber, l_suppkey, l_orderkey + "#, + // Q7: 3 sort keys {(INTEGER, 7), (BIGINT, 10k), (BIGINT, 1.5M)} + 12 all other columns + r#" + SELECT l_linenumber, l_suppkey, l_orderkey, + l_partkey, l_quantity, l_extendedprice, l_discount, l_tax, + l_returnflag, l_linestatus, l_shipdate, l_commitdate, + l_receiptdate, l_shipinstruct, l_shipmode + FROM lineitem + ORDER BY l_linenumber, l_suppkey, l_orderkey + "#, + // Q8: 4 sort keys {(BIGINT, 1.5M), (BIGINT, 10k), (INTEGER, 7), (VARCHAR, 4.5M)} + no payload column + r#" + SELECT l_orderkey, l_suppkey, l_linenumber, l_comment + FROM lineitem + ORDER BY l_orderkey, l_suppkey, l_linenumber, l_comment + "#, + // Q9: 4 sort keys {(BIGINT, 1.5M), (BIGINT, 10k), (INTEGER, 7), (VARCHAR, 4.5M)} + 1 payload column + r#" + SELECT l_orderkey, l_suppkey, l_linenumber, l_comment, l_partkey + FROM lineitem + ORDER BY l_orderkey, l_suppkey, l_linenumber, l_comment + "#, + // Q10: 4 sort keys {(BIGINT, 1.5M), (BIGINT, 10k), (INTEGER, 7), (VARCHAR, 4.5M)} + 12 all other columns + r#" + SELECT l_orderkey, l_suppkey, l_linenumber, l_comment, + l_partkey, l_quantity, l_extendedprice, l_discount, l_tax, + l_returnflag, l_linestatus, l_shipdate, l_commitdate, + l_receiptdate, l_shipinstruct, l_shipmode + FROM lineitem + ORDER BY l_orderkey, l_suppkey, l_linenumber, l_comment + "#, + ]; + + /// If query is specified from command line, run only that query. + /// Otherwise, run all queries. + pub async fn run(&self) -> Result<()> { + let mut benchmark_run = BenchmarkRun::new(); + + let query_range = match self.query { + Some(query_id) => query_id..=query_id, + None => 1..=Self::SORT_QUERIES.len(), + }; + + for query_id in query_range { + benchmark_run.start_new_case(&format!("{query_id}")); + + let query_results = self.benchmark_query(query_id).await?; + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + + Ok(()) + } + + /// Benchmark query `query_id` in `SORT_QUERIES` + async fn benchmark_query(&self, query_id: usize) -> Result> { + let config = self.common.config(); + + let runtime_config = RuntimeConfig::new().build_arc()?; + let ctx = SessionContext::new_with_config_rt(config, runtime_config); + + // register tables + self.register_tables(&ctx).await?; + + let mut millis = vec![]; + // run benchmark + let mut query_results = vec![]; + for i in 0..self.iterations() { + let start = Instant::now(); + + let query_idx = query_id - 1; // 1-indexed -> 0-indexed + let sql = Self::SORT_QUERIES[query_idx]; + + let row_count = self.execute_query(&ctx, sql).await?; + + let elapsed = start.elapsed(); //.as_secs_f64() * 1000.0; + let ms = elapsed.as_secs_f64() * 1000.0; + millis.push(ms); + + println!( + "Q{query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" + ); + query_results.push(QueryResult { elapsed, row_count }); + } + + let avg = millis.iter().sum::() / millis.len() as f64; + println!("Q{query_id} avg time: {avg:.2} ms"); + + Ok(query_results) + } + + async fn register_tables(&self, ctx: &SessionContext) -> Result<()> { + for table in Self::SORT_TABLES { + let table_provider = { self.get_table(ctx, table).await? }; + + if self.mem_table { + println!("Loading table '{table}' into memory"); + let start = Instant::now(); + let memtable = + MemTable::load(table_provider, Some(self.partitions()), &ctx.state()) + .await?; + println!( + "Loaded table '{}' into memory in {} ms", + table, + start.elapsed().as_millis() + ); + ctx.register_table(table, Arc::new(memtable))?; + } else { + ctx.register_table(table, table_provider)?; + } + } + Ok(()) + } + + async fn execute_query(&self, ctx: &SessionContext, sql: &str) -> Result { + let debug = self.common.debug; + let plan = ctx.sql(sql).await?; + let (state, plan) = plan.into_parts(); + + if debug { + println!("=== Logical plan ===\n{plan}\n"); + } + + let plan = state.optimize(&plan)?; + if debug { + println!("=== Optimized logical plan ===\n{plan}\n"); + } + let physical_plan = state.create_physical_plan(&plan).await?; + if debug { + println!( + "=== Physical plan ===\n{}\n", + displayable(physical_plan.as_ref()).indent(true) + ); + } + + let mut row_count = 0; + + let mut stream = execute_stream(physical_plan.clone(), state.task_ctx())?; + while let Some(batch) = stream.next().await { + row_count += batch.unwrap().num_rows(); + } + + if debug { + println!( + "=== Physical plan with metrics ===\n{}\n", + DisplayableExecutionPlan::with_metrics(physical_plan.as_ref()) + .indent(true) + ); + } + + Ok(row_count) + } + + async fn get_table( + &self, + ctx: &SessionContext, + table: &str, + ) -> Result> { + let path = self.path.to_str().unwrap(); + + // Obtain a snapshot of the SessionState + let state = ctx.state(); + let path = format!("{path}/{table}"); + let format = Arc::new( + ParquetFormat::default() + .with_options(ctx.state().table_options().parquet.clone()), + ); + let extension = DEFAULT_PARQUET_EXTENSION; + + let options = ListingOptions::new(format) + .with_file_extension(extension) + .with_collect_stat(state.config().collect_statistics()); + + let table_path = ListingTableUrl::parse(path)?; + let config = ListingTableConfig::new(table_path).with_listing_options(options); + let config = config.infer_schema(&state).await?; + + Ok(Arc::new(ListingTable::try_new(config)?)) + } + + fn iterations(&self) -> usize { + self.common.iterations + } + + fn partitions(&self) -> usize { + self.common.partitions.unwrap_or(num_cpus::get()) + } +} From 6d8313ebc865f9bff007bfc04652f58b016cbc1b Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 15 Nov 2024 15:51:12 -0600 Subject: [PATCH 07/45] fix docs of register_table to match implementation (#13438) I'm not sure that changing the implementation is possible at this point. We could call deregister_table but I fear that's not atomic. So we'd have to change the implementation of SchemaProvider, a breaking change. --- datafusion/core/src/execution/context/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 45dfe835880f8..5f01d41c31e73 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1428,8 +1428,8 @@ impl SessionContext { /// Registers a [`TableProvider`] as a table that can be /// referenced from SQL statements executed against this context. /// - /// Returns the [`TableProvider`] previously registered for this - /// reference, if any + /// If a table of the same name was already registered, returns "Table + /// already exists" error. pub fn register_table( &self, table_ref: impl Into, From a09814ab4ec2b8c0d9e4472a40c85e2b36e024e5 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 15 Nov 2024 22:41:49 -0500 Subject: [PATCH 08/45] Minor: Remove MOVED file (#13442) --- datafusion/core/tests/sqllogictests/MOVED.md | 20 -------------------- 1 file changed, 20 deletions(-) delete mode 100644 datafusion/core/tests/sqllogictests/MOVED.md diff --git a/datafusion/core/tests/sqllogictests/MOVED.md b/datafusion/core/tests/sqllogictests/MOVED.md deleted file mode 100644 index dd70dab9d11f2..0000000000000 --- a/datafusion/core/tests/sqllogictests/MOVED.md +++ /dev/null @@ -1,20 +0,0 @@ - - -The SQL Logic Test code has moved to `datafusion/sqllogictest` From 06db9ed865dc48bb1c87ce60d85331d385ee0f17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alihan=20=C3=87elikcan?= Date: Sat, 16 Nov 2024 09:34:16 +0300 Subject: [PATCH 09/45] Deduplicate and standardize deserialization logic for streams (#13412) * Add BatchDeserializer * Fix formatting * Remove unused enum value * Update datafusion/core/src/datasource/file_format/mod.rs --------- Co-authored-by: Mehmet Ozan Kabak --- .../core/src/datasource/file_format/csv.rs | 235 +++++++++++++++++- .../core/src/datasource/file_format/json.rs | 134 +++++++++- .../core/src/datasource/file_format/mod.rs | 171 ++++++++++++- .../core/src/datasource/physical_plan/csv.rs | 41 +-- .../core/src/datasource/physical_plan/json.rs | 38 +-- 5 files changed, 547 insertions(+), 72 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index d59e2bf71d642..9f979ddf01e78 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -23,7 +23,10 @@ use std::fmt::{self, Debug}; use std::sync::Arc; use super::write::orchestration::stateless_multipart_put; -use super::{FileFormat, FileFormatFactory, DEFAULT_SCHEMA_INFER_MAX_RECORD}; +use super::{ + Decoder, DecoderDeserializer, FileFormat, FileFormatFactory, + DEFAULT_SCHEMA_INFER_MAX_RECORD, +}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::file_format::write::BatchSerializer; use crate::datasource::physical_plan::{ @@ -38,8 +41,8 @@ use crate::physical_plan::{ use arrow::array::RecordBatch; use arrow::csv::WriterBuilder; -use arrow::datatypes::SchemaRef; -use arrow::datatypes::{DataType, Field, Fields, Schema}; +use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; +use arrow_schema::ArrowError; use datafusion_common::config::{ConfigField, ConfigFileType, CsvOptions}; use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::{ @@ -293,6 +296,45 @@ impl CsvFormat { } } +#[derive(Debug)] +pub(crate) struct CsvDecoder { + inner: arrow::csv::reader::Decoder, +} + +impl CsvDecoder { + pub(crate) fn new(decoder: arrow::csv::reader::Decoder) -> Self { + Self { inner: decoder } + } +} + +impl Decoder for CsvDecoder { + fn decode(&mut self, buf: &[u8]) -> Result { + self.inner.decode(buf) + } + + fn flush(&mut self) -> Result, ArrowError> { + self.inner.flush() + } + + fn can_flush_early(&self) -> bool { + self.inner.capacity() == 0 + } +} + +impl Debug for CsvSerializer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("CsvSerializer") + .field("header", &self.header) + .finish() + } +} + +impl From for DecoderDeserializer { + fn from(decoder: arrow::csv::reader::Decoder) -> Self { + DecoderDeserializer::new(CsvDecoder::new(decoder)) + } +} + #[async_trait] impl FileFormat for CsvFormat { fn as_any(&self) -> &dyn Any { @@ -692,23 +734,28 @@ impl DataSink for CsvSink { mod tests { use super::super::test_util::scan_format; use super::*; - use crate::arrow::util::pretty; use crate::assert_batches_eq; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::file_format::test_util::VariableStream; + use crate::datasource::file_format::{ + BatchDeserializer, DecoderDeserializer, DeserializerOutput, + }; use crate::datasource::listing::ListingOptions; + use crate::execution::session_state::SessionStateBuilder; use crate::physical_plan::collect; use crate::prelude::{CsvReadOptions, SessionConfig, SessionContext}; use crate::test_util::arrow_test_data; use arrow::compute::concat_batches; + use arrow::csv::ReaderBuilder; + use arrow::util::pretty::pretty_format_batches; + use arrow_array::{BooleanArray, Float64Array, Int32Array, StringArray}; use datafusion_common::cast::as_string_array; use datafusion_common::internal_err; use datafusion_common::stats::Precision; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_expr::{col, lit}; - use crate::execution::session_state::SessionStateBuilder; use chrono::DateTime; use object_store::local::LocalFileSystem; use object_store::path::Path; @@ -1097,7 +1144,7 @@ mod tests { ) -> Result { let df = ctx.sql(&format!("EXPLAIN {sql}")).await?; let result = df.collect().await?; - let plan = format!("{}", &pretty::pretty_format_batches(&result)?); + let plan = format!("{}", &pretty_format_batches(&result)?); let re = Regex::new(r"CsvExec: file_groups=\{(\d+) group").unwrap(); @@ -1464,4 +1511,180 @@ mod tests { Ok(()) } + + #[rstest] + fn test_csv_deserializer_with_finish( + #[values(1, 5, 17)] batch_size: usize, + #[values(0, 5, 93)] line_count: usize, + ) -> Result<()> { + let schema = csv_schema(); + let generator = CsvBatchGenerator::new(batch_size, line_count); + let mut deserializer = csv_deserializer(batch_size, &schema); + + for data in generator { + deserializer.digest(data); + } + deserializer.finish(); + + let batch_count = line_count.div_ceil(batch_size); + + let mut all_batches = RecordBatch::new_empty(schema.clone()); + for _ in 0..batch_count { + let output = deserializer.next()?; + let DeserializerOutput::RecordBatch(batch) = output else { + panic!("Expected RecordBatch, got {:?}", output); + }; + all_batches = concat_batches(&schema, &[all_batches, batch])?; + } + assert_eq!(deserializer.next()?, DeserializerOutput::InputExhausted); + + let expected = csv_expected_batch(schema, line_count)?; + + assert_eq!( + expected.clone(), + all_batches.clone(), + "Expected:\n{}\nActual:\n{}", + pretty_format_batches(&[expected])?, + pretty_format_batches(&[all_batches])?, + ); + + Ok(()) + } + + #[rstest] + fn test_csv_deserializer_without_finish( + #[values(1, 5, 17)] batch_size: usize, + #[values(0, 5, 93)] line_count: usize, + ) -> Result<()> { + let schema = csv_schema(); + let generator = CsvBatchGenerator::new(batch_size, line_count); + let mut deserializer = csv_deserializer(batch_size, &schema); + + for data in generator { + deserializer.digest(data); + } + + let batch_count = line_count / batch_size; + + let mut all_batches = RecordBatch::new_empty(schema.clone()); + for _ in 0..batch_count { + let output = deserializer.next()?; + let DeserializerOutput::RecordBatch(batch) = output else { + panic!("Expected RecordBatch, got {:?}", output); + }; + all_batches = concat_batches(&schema, &[all_batches, batch])?; + } + assert_eq!(deserializer.next()?, DeserializerOutput::RequiresMoreData); + + let expected = csv_expected_batch(schema, batch_count * batch_size)?; + + assert_eq!( + expected.clone(), + all_batches.clone(), + "Expected:\n{}\nActual:\n{}", + pretty_format_batches(&[expected])?, + pretty_format_batches(&[all_batches])?, + ); + + Ok(()) + } + + struct CsvBatchGenerator { + batch_size: usize, + line_count: usize, + offset: usize, + } + + impl CsvBatchGenerator { + fn new(batch_size: usize, line_count: usize) -> Self { + Self { + batch_size, + line_count, + offset: 0, + } + } + } + + impl Iterator for CsvBatchGenerator { + type Item = Bytes; + + fn next(&mut self) -> Option { + // Return `batch_size` rows per batch: + let mut buffer = Vec::new(); + for _ in 0..self.batch_size { + if self.offset >= self.line_count { + break; + } + buffer.extend_from_slice(&csv_line(self.offset)); + self.offset += 1; + } + + (!buffer.is_empty()).then(|| buffer.into()) + } + } + + fn csv_expected_batch( + schema: SchemaRef, + line_count: usize, + ) -> Result { + let mut c1 = Vec::with_capacity(line_count); + let mut c2 = Vec::with_capacity(line_count); + let mut c3 = Vec::with_capacity(line_count); + let mut c4 = Vec::with_capacity(line_count); + + for i in 0..line_count { + let (int_value, float_value, bool_value, char_value) = csv_values(i); + c1.push(int_value); + c2.push(float_value); + c3.push(bool_value); + c4.push(char_value); + } + + let expected = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(c1)), + Arc::new(Float64Array::from(c2)), + Arc::new(BooleanArray::from(c3)), + Arc::new(StringArray::from(c4)), + ], + )?; + Ok(expected) + } + + fn csv_line(line_number: usize) -> Bytes { + let (int_value, float_value, bool_value, char_value) = csv_values(line_number); + format!( + "{},{},{},{}\n", + int_value, float_value, bool_value, char_value + ) + .into() + } + + fn csv_values(line_number: usize) -> (i32, f64, bool, String) { + let int_value = line_number as i32; + let float_value = line_number as f64; + let bool_value = line_number % 2 == 0; + let char_value = format!("{}-string", line_number); + (int_value, float_value, bool_value, char_value) + } + + fn csv_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Float64, true), + Field::new("c3", DataType::Boolean, true), + Field::new("c4", DataType::Utf8, true), + ])) + } + + fn csv_deserializer( + batch_size: usize, + schema: &Arc, + ) -> impl BatchDeserializer { + let decoder = ReaderBuilder::new(schema.clone()) + .with_batch_size(batch_size) + .build_decoder(); + DecoderDeserializer::new(CsvDecoder::new(decoder)) + } } diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 4f51dd5ae1f57..e97853e9e7d72 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -26,7 +26,8 @@ use std::sync::Arc; use super::write::orchestration::stateless_multipart_put; use super::{ - FileFormat, FileFormatFactory, FileScanConfig, DEFAULT_SCHEMA_INFER_MAX_RECORD, + Decoder, DecoderDeserializer, FileFormat, FileFormatFactory, FileScanConfig, + DEFAULT_SCHEMA_INFER_MAX_RECORD, }; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::file_format::write::BatchSerializer; @@ -44,6 +45,7 @@ use arrow::datatypes::SchemaRef; use arrow::json; use arrow::json::reader::{infer_json_schema_from_iterator, ValueIter}; use arrow_array::RecordBatch; +use arrow_schema::ArrowError; use datafusion_common::config::{ConfigField, ConfigFileType, JsonOptions}; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::{not_impl_err, GetExt, DEFAULT_JSON_EXTENSION}; @@ -384,16 +386,53 @@ impl DataSink for JsonSink { } } +#[derive(Debug)] +pub(crate) struct JsonDecoder { + inner: json::reader::Decoder, +} + +impl JsonDecoder { + pub(crate) fn new(decoder: json::reader::Decoder) -> Self { + Self { inner: decoder } + } +} + +impl Decoder for JsonDecoder { + fn decode(&mut self, buf: &[u8]) -> Result { + self.inner.decode(buf) + } + + fn flush(&mut self) -> Result, ArrowError> { + self.inner.flush() + } + + fn can_flush_early(&self) -> bool { + false + } +} + +impl From for DecoderDeserializer { + fn from(decoder: json::reader::Decoder) -> Self { + DecoderDeserializer::new(JsonDecoder::new(decoder)) + } +} + #[cfg(test)] mod tests { use super::super::test_util::scan_format; use super::*; + use crate::datasource::file_format::{ + BatchDeserializer, DecoderDeserializer, DeserializerOutput, + }; use crate::execution::options::NdJsonReadOptions; use crate::physical_plan::collect; use crate::prelude::{SessionConfig, SessionContext}; use crate::test::object_store::local_unpartitioned_file; + use arrow::compute::concat_batches; + use arrow::json::ReaderBuilder; use arrow::util::pretty; + use arrow_schema::{DataType, Field}; use datafusion_common::cast::as_int64_array; use datafusion_common::stats::Precision; use datafusion_common::{assert_batches_eq, internal_err}; @@ -612,4 +651,97 @@ mod tests { Ok(()) } + + #[test] + fn test_json_deserializer_finish() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int64, true), + Field::new("c2", DataType::Int64, true), + Field::new("c3", DataType::Int64, true), + Field::new("c4", DataType::Int64, true), + Field::new("c5", DataType::Int64, true), + ])); + let mut deserializer = json_deserializer(1, &schema)?; + + deserializer.digest(r#"{ "c1": 1, "c2": 2, "c3": 3, "c4": 4, "c5": 5 }"#.into()); + deserializer.digest(r#"{ "c1": 6, "c2": 7, "c3": 8, "c4": 9, "c5": 10 }"#.into()); + deserializer + .digest(r#"{ "c1": 11, "c2": 12, "c3": 13, "c4": 14, "c5": 15 }"#.into()); + deserializer.finish(); + + let mut all_batches = RecordBatch::new_empty(schema.clone()); + for _ in 0..3 { + let output = deserializer.next()?; + let DeserializerOutput::RecordBatch(batch) = output else { + panic!("Expected RecordBatch, got {:?}", output); + }; + all_batches = concat_batches(&schema, &[all_batches, batch])? + } + assert_eq!(deserializer.next()?, DeserializerOutput::InputExhausted); + + let expected = [ + "+----+----+----+----+----+", + "| c1 | c2 | c3 | c4 | c5 |", + "+----+----+----+----+----+", + "| 1 | 2 | 3 | 4 | 5 |", + "| 6 | 7 | 8 | 9 | 10 |", + "| 11 | 12 | 13 | 14 | 15 |", + "+----+----+----+----+----+", + ]; + + assert_batches_eq!(expected, &[all_batches]); + + Ok(()) + } + + #[test] + fn test_json_deserializer_no_finish() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int64, true), + Field::new("c2", DataType::Int64, true), + Field::new("c3", DataType::Int64, true), + Field::new("c4", DataType::Int64, true), + Field::new("c5", DataType::Int64, true), + ])); + let mut deserializer = json_deserializer(1, &schema)?; + + deserializer.digest(r#"{ "c1": 1, "c2": 2, "c3": 3, "c4": 4, "c5": 5 }"#.into()); + deserializer.digest(r#"{ "c1": 6, "c2": 7, "c3": 8, "c4": 9, "c5": 10 }"#.into()); + deserializer + .digest(r#"{ "c1": 11, "c2": 12, "c3": 13, "c4": 14, "c5": 15 }"#.into()); + + let mut all_batches = RecordBatch::new_empty(schema.clone()); + // We get RequiresMoreData after 2 batches because of how json::Decoder works + for _ in 0..2 { + let output = deserializer.next()?; + let DeserializerOutput::RecordBatch(batch) = output else { + panic!("Expected RecordBatch, got {:?}", output); + }; + all_batches = concat_batches(&schema, &[all_batches, batch])? + } + assert_eq!(deserializer.next()?, DeserializerOutput::RequiresMoreData); + + let expected = [ + "+----+----+----+----+----+", + "| c1 | c2 | c3 | c4 | c5 |", + "+----+----+----+----+----+", + "| 1 | 2 | 3 | 4 | 5 |", + "| 6 | 7 | 8 | 9 | 10 |", + "+----+----+----+----+----+", + ]; + + assert_batches_eq!(expected, &[all_batches]); + + Ok(()) + } + + fn json_deserializer( + batch_size: usize, + schema: &Arc, + ) -> Result> { + let decoder = ReaderBuilder::new(schema.clone()) + .with_batch_size(batch_size) + .build_decoder()?; + Ok(DecoderDeserializer::new(JsonDecoder::new(decoder))) + } } diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index 5c9eb7f20ae25..eb2a85367f80d 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -32,9 +32,10 @@ pub mod parquet; pub mod write; use std::any::Any; -use std::collections::HashMap; -use std::fmt::{self, Display}; +use std::collections::{HashMap, VecDeque}; +use std::fmt::{self, Debug, Display}; use std::sync::Arc; +use std::task::Poll; use crate::arrow::datatypes::SchemaRef; use crate::datasource::physical_plan::{FileScanConfig, FileSinkConfig}; @@ -42,17 +43,20 @@ use crate::error::Result; use crate::execution::context::SessionState; use crate::physical_plan::{ExecutionPlan, Statistics}; -use arrow_schema::{DataType, Field, FieldRef, Schema}; +use arrow_array::RecordBatch; +use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema}; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{internal_err, not_impl_err, GetExt}; use datafusion_expr::Expr; use datafusion_physical_expr::PhysicalExpr; use async_trait::async_trait; +use bytes::{Buf, Bytes}; use datafusion_physical_expr_common::sort_expr::LexRequirement; use file_compression_type::FileCompressionType; +use futures::stream::BoxStream; +use futures::{ready, Stream, StreamExt}; use object_store::{ObjectMeta, ObjectStore}; -use std::fmt::Debug; /// Factory for creating [`FileFormat`] instances based on session and command level options /// @@ -168,6 +172,165 @@ pub enum FilePushdownSupport { Supported, } +/// Possible outputs of a [`BatchDeserializer`]. +#[derive(Debug, PartialEq)] +pub enum DeserializerOutput { + /// A successfully deserialized [`RecordBatch`]. + RecordBatch(RecordBatch), + /// The deserializer requires more data to make progress. + RequiresMoreData, + /// The input data has been exhausted. + InputExhausted, +} + +/// Trait defining a scheme for deserializing byte streams into structured data. +/// Implementors of this trait are responsible for converting raw bytes into +/// `RecordBatch` objects. +pub trait BatchDeserializer: Send + Debug { + /// Feeds a message for deserialization, updating the internal state of + /// this `BatchDeserializer`. Note that one can call this function multiple + /// times before calling `next`, which will queue multiple messages for + /// deserialization. Returns the number of bytes consumed. + fn digest(&mut self, message: T) -> usize; + + /// Attempts to deserialize any pending messages and returns a + /// `DeserializerOutput` to indicate progress. + fn next(&mut self) -> Result; + + /// Informs the deserializer that no more messages will be provided for + /// deserialization. + fn finish(&mut self); +} + +/// A general interface for decoders such as [`arrow::json::reader::Decoder`] and +/// [`arrow::csv::reader::Decoder`]. Defines an interface similar to +/// [`Decoder::decode`] and [`Decoder::flush`] methods, but also includes +/// a method to check if the decoder can flush early. Intended to be used in +/// conjunction with [`DecoderDeserializer`]. +/// +/// [`arrow::json::reader::Decoder`]: ::arrow::json::reader::Decoder +/// [`arrow::csv::reader::Decoder`]: ::arrow::csv::reader::Decoder +/// [`Decoder::decode`]: ::arrow::json::reader::Decoder::decode +/// [`Decoder::flush`]: ::arrow::json::reader::Decoder::flush +pub(crate) trait Decoder: Send + Debug { + /// See [`arrow::json::reader::Decoder::decode`]. + /// + /// [`arrow::json::reader::Decoder::decode`]: ::arrow::json::reader::Decoder::decode + fn decode(&mut self, buf: &[u8]) -> Result; + + /// See [`arrow::json::reader::Decoder::flush`]. + /// + /// [`arrow::json::reader::Decoder::flush`]: ::arrow::json::reader::Decoder::flush + fn flush(&mut self) -> Result, ArrowError>; + + /// Whether the decoder can flush early in its current state. + fn can_flush_early(&self) -> bool; +} + +impl Debug for DecoderDeserializer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Deserializer") + .field("buffered_queue", &self.buffered_queue) + .field("finalized", &self.finalized) + .finish() + } +} + +impl BatchDeserializer for DecoderDeserializer { + fn digest(&mut self, message: Bytes) -> usize { + if message.is_empty() { + return 0; + } + + let consumed = message.len(); + self.buffered_queue.push_back(message); + consumed + } + + fn next(&mut self) -> Result { + while let Some(buffered) = self.buffered_queue.front_mut() { + let decoded = self.decoder.decode(buffered)?; + buffered.advance(decoded); + + if buffered.is_empty() { + self.buffered_queue.pop_front(); + } + + // Flush when the stream ends or batch size is reached + // Certain implementations can flush early + if decoded == 0 || self.decoder.can_flush_early() { + return match self.decoder.flush() { + Ok(Some(batch)) => Ok(DeserializerOutput::RecordBatch(batch)), + Ok(None) => continue, + Err(e) => Err(e), + }; + } + } + if self.finalized { + Ok(DeserializerOutput::InputExhausted) + } else { + Ok(DeserializerOutput::RequiresMoreData) + } + } + + fn finish(&mut self) { + self.finalized = true; + // Ensure the decoder is flushed: + self.buffered_queue.push_back(Bytes::new()); + } +} + +/// A generic, decoder-based deserialization scheme for processing encoded data. +/// +/// This struct is responsible for converting a stream of bytes, which represent +/// encoded data, into a stream of `RecordBatch` objects, following the specified +/// schema and formatting options. It also handles any buffering necessary to satisfy +/// the `Decoder` interface. +pub(crate) struct DecoderDeserializer { + /// The underlying decoder used for deserialization + pub(crate) decoder: T, + /// The buffer used to store the remaining bytes to be decoded + pub(crate) buffered_queue: VecDeque, + /// Whether the input stream has been fully consumed + pub(crate) finalized: bool, +} + +impl DecoderDeserializer { + /// Creates a new `DecoderDeserializer` with the provided decoder. + pub(crate) fn new(decoder: T) -> Self { + DecoderDeserializer { + decoder, + buffered_queue: VecDeque::new(), + finalized: false, + } + } +} + +/// Deserializes a stream of bytes into a stream of [`RecordBatch`] objects using the +/// provided deserializer. +/// +/// Returns a boxed stream of `Result`. The stream yields [`RecordBatch`] +/// objects as they are produced by the deserializer, or an [`ArrowError`] if an error +/// occurs while polling the input or deserializing. +pub(crate) fn deserialize_stream<'a>( + mut input: impl Stream> + Unpin + Send + 'a, + mut deserializer: impl BatchDeserializer + 'a, +) -> BoxStream<'a, Result> { + futures::stream::poll_fn(move |cx| loop { + match ready!(input.poll_next_unpin(cx)).transpose()? { + Some(b) => _ = deserializer.digest(b), + None => deserializer.finish(), + }; + + return match deserializer.next()? { + DeserializerOutput::RecordBatch(rb) => Poll::Ready(Some(Ok(rb))), + DeserializerOutput::InputExhausted => Poll::Ready(None), + DeserializerOutput::RequiresMoreData => continue, + }; + }) + .boxed() +} + /// A container of [FileFormatFactory] which also implements [FileType]. /// This enables converting a dyn FileFormat to a dyn FileType. /// The former trait is a superset of the latter trait, which includes execution time diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 1679acf30342a..0c41f69c76916 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -24,6 +24,7 @@ use std::task::Poll; use super::{calculate_range, FileGroupPartitioner, FileScanConfig, RangeCalculation}; use crate::datasource::file_format::file_compression_type::FileCompressionType; +use crate::datasource::file_format::{deserialize_stream, DecoderDeserializer}; use crate::datasource::listing::{FileRange, ListingTableUrl, PartitionedFile}; use crate::datasource::physical_plan::file_stream::{ FileOpenFuture, FileOpener, FileStream, @@ -42,8 +43,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; -use bytes::{Buf, Bytes}; -use futures::{ready, StreamExt, TryStreamExt}; +use futures::{StreamExt, TryStreamExt}; use object_store::buffered::BufWriter; use object_store::{GetOptions, GetResultPayload, ObjectStore}; use tokio::io::AsyncWriteExt; @@ -651,36 +651,14 @@ impl FileOpener for CsvOpener { Ok(futures::stream::iter(config.open(decoder)?).boxed()) } GetResultPayload::Stream(s) => { - let mut decoder = config.builder().build_decoder(); + let decoder = config.builder().build_decoder(); let s = s.map_err(DataFusionError::from); - let mut input = - file_compression_type.convert_stream(s.boxed())?.fuse(); - let mut buffered = Bytes::new(); - - let s = futures::stream::poll_fn(move |cx| { - loop { - if buffered.is_empty() { - match ready!(input.poll_next_unpin(cx)) { - Some(Ok(b)) => buffered = b, - Some(Err(e)) => { - return Poll::Ready(Some(Err(e.into()))) - } - None => {} - }; - } - let decoded = match decoder.decode(buffered.as_ref()) { - // Note: the decoder needs to be called with an empty - // array to delimt the final record - Ok(0) => break, - Ok(decoded) => decoded, - Err(e) => return Poll::Ready(Some(Err(e))), - }; - buffered.advance(decoded); - } - - Poll::Ready(decoder.flush().transpose()) - }); - Ok(s.boxed()) + let input = file_compression_type.convert_stream(s.boxed())?.fuse(); + + Ok(deserialize_stream( + input, + DecoderDeserializer::from(decoder), + )) } } })) @@ -753,6 +731,7 @@ mod tests { use crate::{scalar::ScalarValue, test_util::aggr_test_schema}; use arrow::datatypes::*; + use bytes::Bytes; use datafusion_common::test_util::arrow_test_data; use datafusion_common::config::CsvOptions; diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 7b0a605aed05e..c86f8fbd262ff 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -24,6 +24,7 @@ use std::task::Poll; use super::{calculate_range, FileGroupPartitioner, FileScanConfig, RangeCalculation}; use crate::datasource::file_format::file_compression_type::FileCompressionType; +use crate::datasource::file_format::{deserialize_stream, DecoderDeserializer}; use crate::datasource::listing::{ListingTableUrl, PartitionedFile}; use crate::datasource::physical_plan::file_stream::{ FileOpenFuture, FileOpener, FileStream, @@ -41,8 +42,7 @@ use arrow::{datatypes::SchemaRef, json}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; -use bytes::{Buf, Bytes}; -use futures::{ready, StreamExt, TryStreamExt}; +use futures::{StreamExt, TryStreamExt}; use object_store::buffered::BufWriter; use object_store::{GetOptions, GetResultPayload, ObjectStore}; use tokio::io::AsyncWriteExt; @@ -312,37 +312,15 @@ impl FileOpener for JsonOpener { GetResultPayload::Stream(s) => { let s = s.map_err(DataFusionError::from); - let mut decoder = ReaderBuilder::new(schema) + let decoder = ReaderBuilder::new(schema) .with_batch_size(batch_size) .build_decoder()?; - let mut input = - file_compression_type.convert_stream(s.boxed())?.fuse(); - let mut buffer = Bytes::new(); - - let s = futures::stream::poll_fn(move |cx| { - loop { - if buffer.is_empty() { - match ready!(input.poll_next_unpin(cx)) { - Some(Ok(b)) => buffer = b, - Some(Err(e)) => { - return Poll::Ready(Some(Err(e.into()))) - } - None => {} - }; - } - - let decoded = match decoder.decode(buffer.as_ref()) { - Ok(0) => break, - Ok(decoded) => decoded, - Err(e) => return Poll::Ready(Some(Err(e))), - }; - - buffer.advance(decoded); - } + let input = file_compression_type.convert_stream(s.boxed())?.fuse(); - Poll::Ready(decoder.flush().transpose()) - }); - Ok(s.boxed()) + Ok(deserialize_stream( + input, + DecoderDeserializer::from(decoder), + )) } } })) From 73507c307487708deb321e1ba4e0d302084ca27e Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Sat, 16 Nov 2024 18:03:10 +0800 Subject: [PATCH 10/45] organize ExternalSorter fields (#13447) --- datafusion/physical-plan/src/sorts/sort.rs | 43 +++++++++++++++------- 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 9f7bd6b28a2e9..e9f17ddebabc9 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -203,39 +203,54 @@ impl ExternalSorterMetrics { /// in_mem_batches /// ``` struct ExternalSorter { - /// schema of the output (and the input) + // ======================================================================== + // PROPERTIES: + // Fields that define the sorter's configuration and remain constant + // ======================================================================== + /// Schema of the output (and the input) schema: SchemaRef, + /// Sort expressions + expr: Arc<[PhysicalSortExpr]>, + /// If Some, the maximum number of output rows that will be produced + fetch: Option, + /// The target number of rows for output batches + batch_size: usize, + /// If the in size of buffered memory batches is below this size, + /// the data will be concatenated and sorted in place rather than + /// sort/merged. + sort_in_place_threshold_bytes: usize, + + // ======================================================================== + // STATE BUFFERS: + // Fields that hold intermediate data during sorting + // ======================================================================== /// Potentially unsorted in memory buffer in_mem_batches: Vec, /// if `Self::in_mem_batches` are sorted in_mem_batches_sorted: bool, + /// If data has previously been spilled, the locations of the /// spill files (in Arrow IPC format) spills: Vec, - /// Sort expressions - expr: Arc<[PhysicalSortExpr]>, + + // ======================================================================== + // EXECUTION RESOURCES: + // Fields related to managing execution resources and monitoring performance. + // ======================================================================== /// Runtime metrics metrics: ExternalSorterMetrics, - /// If Some, the maximum number of output rows that will be - /// produced. - fetch: Option, + /// A handle to the runtime to get spill files + runtime: Arc, /// Reservation for in_mem_batches reservation: MemoryReservation, + /// Reservation for the merging of in-memory batches. If the sort /// might spill, `sort_spill_reservation_bytes` will be /// pre-reserved to ensure there is some space for this sort/merge. merge_reservation: MemoryReservation, - /// A handle to the runtime to get spill files - runtime: Arc, - /// The target number of rows for output batches - batch_size: usize, /// How much memory to reserve for performing in-memory sort/merges /// prior to spilling. sort_spill_reservation_bytes: usize, - /// If the in size of buffered memory batches is below this size, - /// the data will be concatenated and sorted in place rather than - /// sort/merged. - sort_in_place_threshold_bytes: usize, } impl ExternalSorter { From 61fa572ff97444b5d70ea93209e3e440497afded Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Sun, 17 Nov 2024 05:51:12 -0500 Subject: [PATCH 11/45] feat: Add `stringview` support to `encode` and `decode` and `bit_length` (#13332) * add stringview * add tests * remove utf8view * remove array_to_string changes * remove use --- datafusion/functions/src/core/named_struct.rs | 4 +-- datafusion/functions/src/encoding/inner.rs | 20 ++++++++---- .../sqllogictest/test_files/encoding.slt | 31 +++++++++++++++++++ datafusion/sqllogictest/test_files/expr.slt | 5 +++ 4 files changed, 52 insertions(+), 8 deletions(-) diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index d53dd2277f844..0211ed3fe691b 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -24,9 +24,9 @@ use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::{Arc, OnceLock}; -/// put values in a struct array. +/// Put values in a struct array. fn named_struct_expr(args: &[ColumnarValue]) -> Result { - // do not accept 0 arguments. + // Do not accept 0 arguments. if args.is_empty() { return exec_err!( "named_struct requires at least one pair of arguments, got 0 instead" diff --git a/datafusion/functions/src/encoding/inner.rs b/datafusion/functions/src/encoding/inner.rs index 4f91879f94db7..0649c7cbb5c0f 100644 --- a/datafusion/functions/src/encoding/inner.rs +++ b/datafusion/functions/src/encoding/inner.rs @@ -108,7 +108,7 @@ impl ScalarUDFImpl for EncodeFunc { } match arg_types[0] { - DataType::Utf8 | DataType::Binary | DataType::Null => { + DataType::Utf8 | DataType::Utf8View | DataType::Binary | DataType::Null => { Ok(vec![DataType::Utf8; 2]) } DataType::LargeUtf8 | DataType::LargeBinary => { @@ -195,7 +195,7 @@ impl ScalarUDFImpl for DecodeFunc { } match arg_types[0] { - DataType::Utf8 | DataType::Binary | DataType::Null => { + DataType::Utf8 | DataType::Utf8View | DataType::Binary | DataType::Null => { Ok(vec![DataType::Binary, DataType::Utf8]) } DataType::LargeUtf8 | DataType::LargeBinary => { @@ -224,6 +224,7 @@ fn encode_process(value: &ColumnarValue, encoding: Encoding) -> Result match a.data_type() { DataType::Utf8 => encoding.encode_utf8_array::(a.as_ref()), DataType::LargeUtf8 => encoding.encode_utf8_array::(a.as_ref()), + DataType::Utf8View => encoding.encode_utf8_array::(a.as_ref()), DataType::Binary => encoding.encode_binary_array::(a.as_ref()), DataType::LargeBinary => encoding.encode_binary_array::(a.as_ref()), other => exec_err!( @@ -237,6 +238,9 @@ fn encode_process(value: &ColumnarValue, encoding: Encoding) -> Result Ok(encoding .encode_large_scalar(a.as_ref().map(|s: &String| s.as_bytes()))), + ScalarValue::Utf8View(a) => { + Ok(encoding.encode_scalar(a.as_ref().map(|s: &String| s.as_bytes()))) + } ScalarValue::Binary(a) => Ok( encoding.encode_scalar(a.as_ref().map(|v: &Vec| v.as_slice())) ), @@ -255,6 +259,7 @@ fn decode_process(value: &ColumnarValue, encoding: Encoding) -> Result match a.data_type() { DataType::Utf8 => encoding.decode_utf8_array::(a.as_ref()), DataType::LargeUtf8 => encoding.decode_utf8_array::(a.as_ref()), + DataType::Utf8View => encoding.decode_utf8_array::(a.as_ref()), DataType::Binary => encoding.decode_binary_array::(a.as_ref()), DataType::LargeBinary => encoding.decode_binary_array::(a.as_ref()), other => exec_err!( @@ -268,6 +273,9 @@ fn decode_process(value: &ColumnarValue, encoding: Encoding) -> Result encoding .decode_large_scalar(a.as_ref().map(|s: &String| s.as_bytes())), + ScalarValue::Utf8View(a) => { + encoding.decode_scalar(a.as_ref().map(|s: &String| s.as_bytes())) + } ScalarValue::Binary(a) => { encoding.decode_scalar(a.as_ref().map(|v: &Vec| v.as_slice())) } @@ -512,7 +520,7 @@ impl FromStr for Encoding { } } -/// Encodes the given data, accepts Binary, LargeBinary, Utf8 or LargeUtf8 and returns a [`ColumnarValue`]. +/// Encodes the given data, accepts Binary, LargeBinary, Utf8, Utf8View or LargeUtf8 and returns a [`ColumnarValue`]. /// Second argument is the encoding to use. /// Standard encodings are base64 and hex. fn encode(args: &[ColumnarValue]) -> Result { @@ -524,7 +532,7 @@ fn encode(args: &[ColumnarValue]) -> Result { } let encoding = match &args[1] { ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => { + ScalarValue::Utf8(Some(method)) | ScalarValue::Utf8View(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => { method.parse::() } _ => not_impl_err!( @@ -538,7 +546,7 @@ fn encode(args: &[ColumnarValue]) -> Result { encode_process(&args[0], encoding) } -/// Decodes the given data, accepts Binary, LargeBinary, Utf8 or LargeUtf8 and returns a [`ColumnarValue`]. +/// Decodes the given data, accepts Binary, LargeBinary, Utf8, Utf8View or LargeUtf8 and returns a [`ColumnarValue`]. /// Second argument is the encoding to use. /// Standard encodings are base64 and hex. fn decode(args: &[ColumnarValue]) -> Result { @@ -550,7 +558,7 @@ fn decode(args: &[ColumnarValue]) -> Result { } let encoding = match &args[1] { ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => { + ScalarValue::Utf8(Some(method)) | ScalarValue::Utf8View(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => { method.parse::() } _ => not_impl_err!( diff --git a/datafusion/sqllogictest/test_files/encoding.slt b/datafusion/sqllogictest/test_files/encoding.slt index 68bdf78115aac..fc22cc8bf7a7e 100644 --- a/datafusion/sqllogictest/test_files/encoding.slt +++ b/datafusion/sqllogictest/test_files/encoding.slt @@ -71,3 +71,34 @@ select to_hex(num) from test ORDER BY num; 0 1 2 + +# test for Utf8View support for encode +statement ok +CREATE TABLE test_source AS VALUES + ('Andrew', 'X'), + ('Xiangpeng', 'Xiangpeng'), + ('Raphael', 'R'), + (NULL, 'R'); + +statement ok +CREATE TABLE test_utf8view AS +select + arrow_cast(column1, 'Utf8View') AS column1_utf8view, + arrow_cast(column2, 'Utf8View') AS column2_utf8view +FROM test_source; + +query TTTTTT +SELECT + column1_utf8view, + encode(column1_utf8view, 'base64') AS column1_base64, + encode(column1_utf8view, 'hex') AS column1_hex, + + column2_utf8view, + encode(column2_utf8view, 'base64') AS column2_base64, + encode(column2_utf8view, 'hex') AS column2_hex +FROM test_utf8view; +---- +Andrew QW5kcmV3 416e64726577 X WA 58 +Xiangpeng WGlhbmdwZW5n 5869616e6770656e67 Xiangpeng WGlhbmdwZW5n 5869616e6770656e67 +Raphael UmFwaGFlbA 5261706861656c R Ug 52 +NULL NULL NULL R Ug 52 \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 15bf771c65271..31467072dd3e4 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -364,6 +364,11 @@ SELECT bit_length(NULL) ---- NULL +query I +SELECT bit_length(arrow_cast('jonathan', 'Utf8View')); +---- +64 + query T SELECT btrim(' xyxtrimyyx ', NULL) ---- From b75563b3a96f291577d3ed22e79d663988e5d269 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Sun, 17 Nov 2024 18:51:47 +0800 Subject: [PATCH 12/45] Support unparsing Array plan to SQL string (#13418) * unparse construct and access * add sql e2e roundtrip * remove unused tests * fix test and clippy * fix clippy --- datafusion/functions-nested/src/planner.rs | 1 - datafusion/sql/src/unparser/expr.rs | 47 +++++++++++++++++++++- datafusion/sql/tests/cases/plan_to_sql.rs | 15 ++++++- 3 files changed, 58 insertions(+), 5 deletions(-) diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index 9ae2fa781d87e..1929b8222a1b6 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -133,7 +133,6 @@ impl ExprPlanner for NestedFunctionPlanner { #[derive(Debug)] pub struct FieldAccessPlanner; - impl ExprPlanner for FieldAccessPlanner { fn plan_field_access( &self, diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 8f6ffa51f76a3..8664abd6543ef 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -18,8 +18,8 @@ use datafusion_expr::expr::Unnest; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ - self, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, ObjectName, - TimezoneInfo, UnaryOperator, + self, Array, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, ObjectName, + Subscript, TimezoneInfo, UnaryOperator, }; use std::sync::Arc; use std::vec; @@ -476,6 +476,19 @@ impl Unparser<'_> { &self, func_name: &str, args: &[Expr], + ) -> Result { + match func_name { + "make_array" => self.make_array_to_sql(args), + "array_element" => self.array_element_to_sql(args), + // TODO: support for the construct and access functions of the `map` and `struct` types + _ => self.scalar_function_to_sql_internal(func_name, args), + } + } + + fn scalar_function_to_sql_internal( + &self, + func_name: &str, + args: &[Expr], ) -> Result { let args = self.function_args_to_sql(args)?; Ok(ast::Expr::Function(Function { @@ -496,6 +509,29 @@ impl Unparser<'_> { })) } + fn make_array_to_sql(&self, args: &[Expr]) -> Result { + let args = args + .iter() + .map(|e| self.expr_to_sql(e)) + .collect::>>()?; + Ok(ast::Expr::Array(Array { + elem: args, + named: false, + })) + } + + fn array_element_to_sql(&self, args: &[Expr]) -> Result { + if args.len() != 2 { + return internal_err!("array_element must have exactly 2 arguments"); + } + let array = self.expr_to_sql(&args[0])?; + let index = self.expr_to_sql(&args[1])?; + Ok(ast::Expr::Subscript { + expr: Box::new(array), + subscript: Box::new(Subscript::Index { index }), + }) + } + pub fn sort_to_sql(&self, sort: &Sort) -> Result { let Sort { expr, @@ -1485,6 +1521,7 @@ mod tests { use datafusion_expr::{interval_month_day_nano_lit, ExprFunctionExt}; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::expr_fn::sum; + use datafusion_functions_nested::expr_fn::{array_element, make_array}; use datafusion_functions_window::row_number::row_number_udwf; use crate::unparser::dialect::{ @@ -1889,6 +1926,12 @@ mod tests { }), r#"UNNEST("table".array_col)"#, ), + (make_array(vec![lit(1), lit(2), lit(3)]), "[1, 2, 3]"), + (array_element(col("array_col"), lit(1)), "array_col[1]"), + ( + array_element(make_array(vec![lit(1), lit(2), lit(3)]), lit(1)), + "[1, 2, 3][1]", + ), ]; for (expr, expected) in tests { diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 94e420066d8b8..4f43d7333dd14 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -40,6 +40,8 @@ use datafusion_expr::builder::{ table_scan_with_filter_and_fetch, table_scan_with_filters, }; use datafusion_functions::core::planner::CoreFunctionPlanner; +use datafusion_functions_nested::extract::array_element_udf; +use datafusion_functions_nested::planner::{FieldAccessPlanner, NestedFunctionPlanner}; use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect}; use sqlparser::parser::Parser; @@ -182,6 +184,11 @@ fn roundtrip_statement() -> Result<()> { SUM(id) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_total FROM person GROUP BY GROUPING SETS ((id, first_name, last_name), (first_name, last_name), (last_name))"#, + "SELECT ARRAY[1, 2, 3]", + "SELECT ARRAY[1, 2, 3][1]", + "SELECT [1, 2, 3]", + "SELECT [1, 2, 3][1]", + "SELECT left[1] FROM array" ]; // For each test sql string, we transform as follows: @@ -195,10 +202,14 @@ fn roundtrip_statement() -> Result<()> { .try_with_sql(query)? .parse_statement()?; let state = MockSessionState::default() + .with_scalar_function(make_array_udf()) + .with_scalar_function(array_element_udf()) .with_aggregate_function(sum_udaf()) .with_aggregate_function(count_udaf()) .with_aggregate_function(max_udaf()) - .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())) + .with_expr_planner(Arc::new(NestedFunctionPlanner)) + .with_expr_planner(Arc::new(FieldAccessPlanner)); let context = MockContextProvider { state }; let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); @@ -1239,6 +1250,6 @@ fn test_unnest_to_sql() { sql_round_trip( GenericDialect {}, r#"SELECT unnest(make_array(1, 2, 2, 5, NULL)) as u1"#, - r#"SELECT UNNEST(make_array(1, 2, 2, 5, NULL)) AS u1"#, + r#"SELECT UNNEST([1, 2, 2, 5, NULL]) AS u1"#, ); } From 2db90feaa1f19c9d547f7f7e527e0e2ffc426435 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Sun, 17 Nov 2024 12:33:23 +0100 Subject: [PATCH 13/45] Fix test query results even for quick test execution (#13453) --- datafusion/sqllogictest/test_files/insert_to_external.slt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sqllogictest/test_files/insert_to_external.slt b/datafusion/sqllogictest/test_files/insert_to_external.slt index 35decd728eed7..edfc2ee75bd75 100644 --- a/datafusion/sqllogictest/test_files/insert_to_external.slt +++ b/datafusion/sqllogictest/test_files/insert_to_external.slt @@ -610,7 +610,7 @@ select count(distinct e) from test_column_defaults # Expect all rows to be true as now() was inserted into the table query B rowsort -select e < now() from test_column_defaults +select e <= now() from test_column_defaults ---- true true From e2376c422b7913fb00918f6d3a0b61c4c94792d4 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <33904309+akurmustafa@users.noreply.github.com> Date: Sun, 17 Nov 2024 03:33:59 -0800 Subject: [PATCH 14/45] [MINOR]: fix min max accumulator nan bug (#13432) * fix nan bug * Swap order of the operations --- datafusion/functions-aggregate/src/min_max.rs | 17 +++++++++++++---- datafusion/sqllogictest/test_files/group_by.slt | 16 ++++++++++++++++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index b497953bc5913..618edd343f7d8 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -44,6 +44,7 @@ use datafusion_common::{ use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; use datafusion_physical_expr::expressions; +use std::cmp::Ordering; use std::fmt::Debug; use arrow::datatypes::i256; @@ -113,8 +114,12 @@ macro_rules! primitive_max_accumulator { ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{ Ok(Box::new( PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new($DATA_TYPE, |cur, new| { - if *cur < new { - *cur = new + match (new).partial_cmp(cur) { + Some(Ordering::Greater) | None => { + // new is Greater or None + *cur = new + } + _ => {} } }) // Initialize each accumulator to $NATIVE::MIN @@ -132,8 +137,12 @@ macro_rules! primitive_min_accumulator { ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{ Ok(Box::new( PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$DATA_TYPE, |cur, new| { - if *cur > new { - *cur = new + match (new).partial_cmp(cur) { + Some(Ordering::Less) | None => { + // new is Less or NaN + *cur = new + } + _ => {} } }) // Initialize each accumulator to $NATIVE::MAX diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 4b90ddf2ea5f1..391f84836871c 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -5271,3 +5271,19 @@ drop view t statement ok drop table source; + +# Test whether min, max accumulator produces NaN result when input is NaN. +# See https://github.com/apache/datafusion/issues/13415 for rationale +statement ok +CREATE TABLE input_table ( + "row" integer, + "x" double precision +); + +statement ok +INSERT INTO input_table VALUES (1, 'NaN'); + +query RR +SELECT max(input_table.x), min(input_table.x) from input_table GROUP BY input_table."row"; +---- +NaN NaN From a8921016d4894c6f5dc2689d91f63a8694bfabd6 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Sun, 17 Nov 2024 12:35:32 +0100 Subject: [PATCH 15/45] Evaluate cheaper condition first in join selection and physical planner (#13435) * Evaluate cheaper condition first in join selection * Evaluate cheaper condition first in physical planner --- .../core/src/physical_optimizer/join_selection.rs | 10 +++++----- datafusion/core/src/physical_planner.rs | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 9b2402c6bb875..fdb2920300dd4 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -382,8 +382,8 @@ fn try_collect_left( match (left_can_collect, right_can_collect) { (true, true) => { - if should_swap_join_order(&**left, &**right)? - && supports_swap(*hash_join.join_type()) + if supports_swap(*hash_join.join_type()) + && should_swap_join_order(&**left, &**right)? { Ok(Some(swap_hash_join(hash_join, PartitionMode::CollectLeft)?)) } else { @@ -423,7 +423,7 @@ fn try_collect_left( fn partitioned_hash_join(hash_join: &HashJoinExec) -> Result> { let left = hash_join.left(); let right = hash_join.right(); - if should_swap_join_order(&**left, &**right)? && supports_swap(*hash_join.join_type()) + if supports_swap(*hash_join.join_type()) && should_swap_join_order(&**left, &**right)? { swap_hash_join(hash_join, PartitionMode::Partitioned) } else { @@ -468,8 +468,8 @@ fn statistical_join_selection_subrule( PartitionMode::Partitioned => { let left = hash_join.left(); let right = hash_join.right(); - if should_swap_join_order(&**left, &**right)? - && supports_swap(*hash_join.join_type()) + if supports_swap(*hash_join.join_type()) + && should_swap_join_order(&**left, &**right)? { swap_hash_join(hash_join, PartitionMode::Partitioned).map(Some)? } else { diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 26f6b12908a7c..69c2ccc04aaa8 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -656,8 +656,8 @@ impl DefaultPhysicalPlanner { let logical_input_schema = input.as_ref().schema(); let physical_input_schema_from_logical = logical_input_schema.inner(); - if &physical_input_schema != physical_input_schema_from_logical - && !options.execution.skip_physical_aggregate_schema_check + if !options.execution.skip_physical_aggregate_schema_check + && &physical_input_schema != physical_input_schema_from_logical { return internal_err!("Physical input schema should be the same as the one converted from logical input schema."); } From cd013c734c110c474776849393e08236d7908299 Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Sun, 17 Nov 2024 03:35:47 -0800 Subject: [PATCH 16/45] Fix duckdb & sqlite character_length scalar unparsing (#13428) * Fix duckdb & sqlite character_length scalar unparsing (#59) * Fix duckdb & sqlite character_length scalar unparsing * Add comments * Update CharacterLengthStyle::SQLStandard to CharacterLengthExtractStyle::CharacterLength * Fix clippy error --- datafusion/sql/src/unparser/dialect.rs | 93 +++++++++++++++++++++++--- datafusion/sql/src/unparser/expr.rs | 31 ++++++++- datafusion/sql/src/unparser/utils.rs | 21 +++++- 3 files changed, 133 insertions(+), 12 deletions(-) diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index 87ed1b8f41409..fbaa402e703ca 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -27,7 +27,7 @@ use sqlparser::{ use datafusion_common::Result; -use super::{utils::date_part_to_sql, Unparser}; +use super::{utils::character_length_to_sql, utils::date_part_to_sql, Unparser}; /// `Dialect` to use for Unparsing /// @@ -80,6 +80,11 @@ pub trait Dialect: Send + Sync { DateFieldExtractStyle::DatePart } + /// The character length extraction style to use: `CharacterLengthStyle` + fn character_length_style(&self) -> CharacterLengthStyle { + CharacterLengthStyle::CharacterLength + } + /// The SQL type to use for Arrow Int64 unparsing /// Most dialects use BigInt, but some, like MySQL, require SIGNED fn int64_cast_dtype(&self) -> ast::DataType { @@ -176,6 +181,17 @@ pub enum DateFieldExtractStyle { Strftime, } +/// `CharacterLengthStyle` to use for unparsing +/// +/// Different DBMSs uses different names for function calculating the number of characters in the string +/// `Length` style uses length(x) +/// `SQLStandard` style uses character_length(x) +#[derive(Clone, Copy, PartialEq)] +pub enum CharacterLengthStyle { + Length, + CharacterLength, +} + pub struct DefaultDialect {} impl Dialect for DefaultDialect { @@ -271,6 +287,35 @@ impl PostgreSqlDialect { } } +pub struct DuckDBDialect {} + +impl Dialect for DuckDBDialect { + fn identifier_quote_style(&self, _: &str) -> Option { + Some('"') + } + + fn character_length_style(&self) -> CharacterLengthStyle { + CharacterLengthStyle::Length + } + + fn scalar_function_to_sql_overrides( + &self, + unparser: &Unparser, + func_name: &str, + args: &[Expr], + ) -> Result> { + if func_name == "character_length" { + return character_length_to_sql( + unparser, + self.character_length_style(), + args, + ); + } + + Ok(None) + } +} + pub struct MySqlDialect {} impl Dialect for MySqlDialect { @@ -347,6 +392,10 @@ impl Dialect for SqliteDialect { ast::DataType::Text } + fn character_length_style(&self) -> CharacterLengthStyle { + CharacterLengthStyle::Length + } + fn supports_column_alias_in_table_alias(&self) -> bool { false } @@ -357,11 +406,15 @@ impl Dialect for SqliteDialect { func_name: &str, args: &[Expr], ) -> Result> { - if func_name == "date_part" { - return date_part_to_sql(unparser, self.date_field_extract_style(), args); + match func_name { + "date_part" => { + date_part_to_sql(unparser, self.date_field_extract_style(), args) + } + "character_length" => { + character_length_to_sql(unparser, self.character_length_style(), args) + } + _ => Ok(None), } - - Ok(None) } } @@ -374,6 +427,7 @@ pub struct CustomDialect { utf8_cast_dtype: ast::DataType, large_utf8_cast_dtype: ast::DataType, date_field_extract_style: DateFieldExtractStyle, + character_length_style: CharacterLengthStyle, int64_cast_dtype: ast::DataType, int32_cast_dtype: ast::DataType, timestamp_cast_dtype: ast::DataType, @@ -395,6 +449,7 @@ impl Default for CustomDialect { utf8_cast_dtype: ast::DataType::Varchar(None), large_utf8_cast_dtype: ast::DataType::Text, date_field_extract_style: DateFieldExtractStyle::DatePart, + character_length_style: CharacterLengthStyle::CharacterLength, int64_cast_dtype: ast::DataType::BigInt(None), int32_cast_dtype: ast::DataType::Integer(None), timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None), @@ -454,6 +509,10 @@ impl Dialect for CustomDialect { self.date_field_extract_style } + fn character_length_style(&self) -> CharacterLengthStyle { + self.character_length_style + } + fn int64_cast_dtype(&self) -> ast::DataType { self.int64_cast_dtype.clone() } @@ -488,11 +547,15 @@ impl Dialect for CustomDialect { func_name: &str, args: &[Expr], ) -> Result> { - if func_name == "date_part" { - return date_part_to_sql(unparser, self.date_field_extract_style(), args); + match func_name { + "date_part" => { + date_part_to_sql(unparser, self.date_field_extract_style(), args) + } + "character_length" => { + character_length_to_sql(unparser, self.character_length_style(), args) + } + _ => Ok(None), } - - Ok(None) } fn requires_derived_table_alias(&self) -> bool { @@ -527,6 +590,7 @@ pub struct CustomDialectBuilder { utf8_cast_dtype: ast::DataType, large_utf8_cast_dtype: ast::DataType, date_field_extract_style: DateFieldExtractStyle, + character_length_style: CharacterLengthStyle, int64_cast_dtype: ast::DataType, int32_cast_dtype: ast::DataType, timestamp_cast_dtype: ast::DataType, @@ -554,6 +618,7 @@ impl CustomDialectBuilder { utf8_cast_dtype: ast::DataType::Varchar(None), large_utf8_cast_dtype: ast::DataType::Text, date_field_extract_style: DateFieldExtractStyle::DatePart, + character_length_style: CharacterLengthStyle::CharacterLength, int64_cast_dtype: ast::DataType::BigInt(None), int32_cast_dtype: ast::DataType::Integer(None), timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None), @@ -578,6 +643,7 @@ impl CustomDialectBuilder { utf8_cast_dtype: self.utf8_cast_dtype, large_utf8_cast_dtype: self.large_utf8_cast_dtype, date_field_extract_style: self.date_field_extract_style, + character_length_style: self.character_length_style, int64_cast_dtype: self.int64_cast_dtype, int32_cast_dtype: self.int32_cast_dtype, timestamp_cast_dtype: self.timestamp_cast_dtype, @@ -620,6 +686,15 @@ impl CustomDialectBuilder { self } + /// Customize the dialect with a specific character_length_style listed in `CharacterLengthStyle` + pub fn with_character_length_style( + mut self, + character_length_style: CharacterLengthStyle, + ) -> Self { + self.character_length_style = character_length_style; + self + } + /// Customize the dialect with a specific SQL type for Float64 casting: DOUBLE, DOUBLE PRECISION, etc. pub fn with_float64_ast_dtype(mut self, float64_ast_dtype: ast::DataType) -> Self { self.float64_ast_dtype = float64_ast_dtype; diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 8664abd6543ef..a6f7c4fd1100f 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1525,8 +1525,8 @@ mod tests { use datafusion_functions_window::row_number::row_number_udwf; use crate::unparser::dialect::{ - CustomDialect, CustomDialectBuilder, DateFieldExtractStyle, Dialect, - PostgreSqlDialect, + CharacterLengthStyle, CustomDialect, CustomDialectBuilder, DateFieldExtractStyle, + Dialect, PostgreSqlDialect, }; use super::*; @@ -2050,6 +2050,33 @@ mod tests { Ok(()) } + #[test] + fn test_character_length_scalar_to_expr() { + let tests = [ + (CharacterLengthStyle::Length, "length(x)"), + (CharacterLengthStyle::CharacterLength, "character_length(x)"), + ]; + + for (style, expected) in tests { + let dialect = CustomDialectBuilder::new() + .with_character_length_style(style) + .build(); + let unparser = Unparser::new(&dialect); + + let expr = ScalarUDF::new_from_impl( + datafusion_functions::unicode::character_length::CharacterLengthFunc::new( + ), + ) + .call(vec![col("x")]); + + let ast = unparser.expr_to_sql(&expr).expect("to be unparsed"); + + let actual = format!("{ast}"); + + assert_eq!(actual, expected); + } + } + #[test] fn test_interval_scalar_to_expr() { let tests = [ diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 284956cef195e..d0f80da83d63f 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -28,7 +28,10 @@ use datafusion_expr::{ }; use sqlparser::ast; -use super::{dialect::DateFieldExtractStyle, rewrite::TableAliasRewriter, Unparser}; +use super::{ + dialect::CharacterLengthStyle, dialect::DateFieldExtractStyle, + rewrite::TableAliasRewriter, Unparser, +}; /// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists /// prior to encountering a Join, TableScan, or a nested subquery (derived table factor). @@ -445,3 +448,19 @@ pub(crate) fn date_part_to_sql( Ok(None) } + +pub(crate) fn character_length_to_sql( + unparser: &Unparser, + style: CharacterLengthStyle, + character_length_args: &[Expr], +) -> Result> { + let func_name = match style { + CharacterLengthStyle::CharacterLength => "character_length", + CharacterLengthStyle::Length => "length", + }; + + Ok(Some(unparser.scalar_function_to_sql( + func_name, + character_length_args, + )?)) +} From e3c4541a1f6034fb47d28be7a82edb3f59001475 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Sun, 17 Nov 2024 12:36:07 +0100 Subject: [PATCH 17/45] chore: remove unnecessary test helpers (#13317) * chore: remove unnecessary test helpers * cargo fmt --- .../simplify_expressions/expr_simplifier.rs | 158 +++++++----------- 1 file changed, 59 insertions(+), 99 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index d8ca246bb6359..6564e722eaf89 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -2865,15 +2865,12 @@ mod tests { ); // single character - assert_change( - regex_match(col("c1"), lit("x")), - like(col("c1"), lit("%x%")), - ); + assert_change(regex_match(col("c1"), lit("x")), col("c1").like(lit("%x%"))); // single word assert_change( regex_match(col("c1"), lit("foo")), - like(col("c1"), lit("%foo%")), + col("c1").like(lit("%foo%")), ); // regular expressions that match an exact literal @@ -2963,48 +2960,53 @@ mod tests { // regular expressions that match a partial literal assert_change( regex_match(col("c1"), lit("^foo")), - like(col("c1"), lit("foo%")), + col("c1").like(lit("foo%")), ); assert_change( regex_match(col("c1"), lit("foo$")), - like(col("c1"), lit("%foo")), + col("c1").like(lit("%foo")), ); assert_change( regex_match(col("c1"), lit("^foo|bar$")), - like(col("c1"), lit("foo%")).or(like(col("c1"), lit("%bar"))), + col("c1").like(lit("foo%")).or(col("c1").like(lit("%bar"))), ); // OR-chain assert_change( regex_match(col("c1"), lit("foo|bar|baz")), - like(col("c1"), lit("%foo%")) - .or(like(col("c1"), lit("%bar%"))) - .or(like(col("c1"), lit("%baz%"))), + col("c1") + .like(lit("%foo%")) + .or(col("c1").like(lit("%bar%"))) + .or(col("c1").like(lit("%baz%"))), ); assert_change( regex_match(col("c1"), lit("foo|x|baz")), - like(col("c1"), lit("%foo%")) - .or(like(col("c1"), lit("%x%"))) - .or(like(col("c1"), lit("%baz%"))), + col("c1") + .like(lit("%foo%")) + .or(col("c1").like(lit("%x%"))) + .or(col("c1").like(lit("%baz%"))), ); assert_change( regex_not_match(col("c1"), lit("foo|bar|baz")), - not_like(col("c1"), lit("%foo%")) - .and(not_like(col("c1"), lit("%bar%"))) - .and(not_like(col("c1"), lit("%baz%"))), + col("c1") + .not_like(lit("%foo%")) + .and(col("c1").not_like(lit("%bar%"))) + .and(col("c1").not_like(lit("%baz%"))), ); // both anchored expressions (translated to equality) and unanchored assert_change( regex_match(col("c1"), lit("foo|^x$|baz")), - like(col("c1"), lit("%foo%")) + col("c1") + .like(lit("%foo%")) .or(col("c1").eq(lit("x"))) - .or(like(col("c1"), lit("%baz%"))), + .or(col("c1").like(lit("%baz%"))), ); assert_change( regex_not_match(col("c1"), lit("foo|^bar$|baz")), - not_like(col("c1"), lit("%foo%")) + col("c1") + .not_like(lit("%foo%")) .and(col("c1").not_eq(lit("bar"))) - .and(not_like(col("c1"), lit("%baz%"))), + .and(col("c1").not_like(lit("%baz%"))), ); // Too many patterns (MAX_REGEX_ALTERNATIONS_EXPANSION) assert_no_change(regex_match(col("c1"), lit("foo|bar|baz|blarg|bozo|etc"))); @@ -3054,46 +3056,6 @@ mod tests { }) } - fn like(expr: Expr, pattern: impl Into) -> Expr { - Expr::Like(Like { - negated: false, - expr: Box::new(expr), - pattern: Box::new(pattern.into()), - escape_char: None, - case_insensitive: false, - }) - } - - fn not_like(expr: Expr, pattern: impl Into) -> Expr { - Expr::Like(Like { - negated: true, - expr: Box::new(expr), - pattern: Box::new(pattern.into()), - escape_char: None, - case_insensitive: false, - }) - } - - fn ilike(expr: Expr, pattern: impl Into) -> Expr { - Expr::Like(Like { - negated: false, - expr: Box::new(expr), - pattern: Box::new(pattern.into()), - escape_char: None, - case_insensitive: true, - }) - } - - fn not_ilike(expr: Expr, pattern: impl Into) -> Expr { - Expr::Like(Like { - negated: true, - expr: Box::new(expr), - pattern: Box::new(pattern.into()), - escape_char: None, - case_insensitive: true, - }) - } - // ------------------------------ // ----- Simplifier tests ------- // ------------------------------ @@ -3703,119 +3665,117 @@ mod tests { let null = lit(ScalarValue::Utf8(None)); // expr [NOT] [I]LIKE NULL - let expr = like(col("c1"), null.clone()); + let expr = col("c1").like(null.clone()); assert_eq!(simplify(expr), lit_bool_null()); - let expr = not_like(col("c1"), null.clone()); + let expr = col("c1").not_like(null.clone()); assert_eq!(simplify(expr), lit_bool_null()); - let expr = ilike(col("c1"), null.clone()); + let expr = col("c1").ilike(null.clone()); assert_eq!(simplify(expr), lit_bool_null()); - let expr = not_ilike(col("c1"), null.clone()); + let expr = col("c1").not_ilike(null.clone()); assert_eq!(simplify(expr), lit_bool_null()); // expr [NOT] [I]LIKE '%' - let expr = like(col("c1"), lit("%")); + let expr = col("c1").like(lit("%")); assert_eq!(simplify(expr), if_not_null(col("c1"), true)); - let expr = not_like(col("c1"), lit("%")); + let expr = col("c1").not_like(lit("%")); assert_eq!(simplify(expr), if_not_null(col("c1"), false)); - let expr = ilike(col("c1"), lit("%")); + let expr = col("c1").ilike(lit("%")); assert_eq!(simplify(expr), if_not_null(col("c1"), true)); - let expr = not_ilike(col("c1"), lit("%")); + let expr = col("c1").not_ilike(lit("%")); assert_eq!(simplify(expr), if_not_null(col("c1"), false)); // expr [NOT] [I]LIKE '%%' - let expr = like(col("c1"), lit("%%")); + let expr = col("c1").like(lit("%%")); assert_eq!(simplify(expr), if_not_null(col("c1"), true)); - let expr = not_like(col("c1"), lit("%%")); + let expr = col("c1").not_like(lit("%%")); assert_eq!(simplify(expr), if_not_null(col("c1"), false)); - let expr = ilike(col("c1"), lit("%%")); + let expr = col("c1").ilike(lit("%%")); assert_eq!(simplify(expr), if_not_null(col("c1"), true)); - let expr = not_ilike(col("c1"), lit("%%")); + let expr = col("c1").not_ilike(lit("%%")); assert_eq!(simplify(expr), if_not_null(col("c1"), false)); // not_null_expr [NOT] [I]LIKE '%' - let expr = like(col("c1_non_null"), lit("%")); + let expr = col("c1_non_null").like(lit("%")); assert_eq!(simplify(expr), lit(true)); - let expr = not_like(col("c1_non_null"), lit("%")); + let expr = col("c1_non_null").not_like(lit("%")); assert_eq!(simplify(expr), lit(false)); - let expr = ilike(col("c1_non_null"), lit("%")); + let expr = col("c1_non_null").ilike(lit("%")); assert_eq!(simplify(expr), lit(true)); - let expr = not_ilike(col("c1_non_null"), lit("%")); + let expr = col("c1_non_null").not_ilike(lit("%")); assert_eq!(simplify(expr), lit(false)); // not_null_expr [NOT] [I]LIKE '%%' - let expr = like(col("c1_non_null"), lit("%%")); + let expr = col("c1_non_null").like(lit("%%")); assert_eq!(simplify(expr), lit(true)); - let expr = not_like(col("c1_non_null"), lit("%%")); + let expr = col("c1_non_null").not_like(lit("%%")); assert_eq!(simplify(expr), lit(false)); - let expr = ilike(col("c1_non_null"), lit("%%")); + let expr = col("c1_non_null").ilike(lit("%%")); assert_eq!(simplify(expr), lit(true)); - let expr = not_ilike(col("c1_non_null"), lit("%%")); + let expr = col("c1_non_null").not_ilike(lit("%%")); assert_eq!(simplify(expr), lit(false)); // null_constant [NOT] [I]LIKE '%' - let expr = like(null.clone(), lit("%")); + let expr = null.clone().like(lit("%")); assert_eq!(simplify(expr), lit_bool_null()); - let expr = not_like(null.clone(), lit("%")); + let expr = null.clone().not_like(lit("%")); assert_eq!(simplify(expr), lit_bool_null()); - let expr = ilike(null.clone(), lit("%")); + let expr = null.clone().ilike(lit("%")); assert_eq!(simplify(expr), lit_bool_null()); - let expr = not_ilike(null, lit("%")); + let expr = null.clone().not_ilike(lit("%")); assert_eq!(simplify(expr), lit_bool_null()); // null_constant [NOT] [I]LIKE '%%' - let null = lit(ScalarValue::Utf8(None)); - let expr = like(null.clone(), lit("%%")); + let expr = null.clone().like(lit("%%")); assert_eq!(simplify(expr), lit_bool_null()); - let expr = not_like(null.clone(), lit("%%")); + let expr = null.clone().not_like(lit("%%")); assert_eq!(simplify(expr), lit_bool_null()); - let expr = ilike(null.clone(), lit("%%")); + let expr = null.clone().ilike(lit("%%")); assert_eq!(simplify(expr), lit_bool_null()); - let expr = not_ilike(null, lit("%%")); + let expr = null.clone().not_ilike(lit("%%")); assert_eq!(simplify(expr), lit_bool_null()); // null_constant [NOT] [I]LIKE 'a%' - let null = lit(ScalarValue::Utf8(None)); - let expr = like(null.clone(), lit("a%")); + let expr = null.clone().like(lit("a%")); assert_eq!(simplify(expr), lit_bool_null()); - let expr = not_like(null.clone(), lit("a%")); + let expr = null.clone().not_like(lit("a%")); assert_eq!(simplify(expr), lit_bool_null()); - let expr = ilike(null.clone(), lit("a%")); + let expr = null.clone().ilike(lit("a%")); assert_eq!(simplify(expr), lit_bool_null()); - let expr = not_ilike(null, lit("a%")); + let expr = null.clone().not_ilike(lit("a%")); assert_eq!(simplify(expr), lit_bool_null()); // expr [NOT] [I]LIKE with pattern without wildcards - let expr = like(col("c1"), lit("a")); + let expr = col("c1").like(lit("a")); assert_eq!(simplify(expr), col("c1").eq(lit("a"))); - let expr = not_like(col("c1"), lit("a")); + let expr = col("c1").not_like(lit("a")); assert_eq!(simplify(expr), col("c1").not_eq(lit("a"))); - let expr = like(col("c1"), lit("a_")); + let expr = col("c1").like(lit("a_")); assert_eq!(simplify(expr), col("c1").like(lit("a_"))); - let expr = not_like(col("c1"), lit("a_")); + let expr = col("c1").not_like(lit("a_")); assert_eq!(simplify(expr), col("c1").not_like(lit("a_"))); } From 97045ec02871042930251a5f1278daf8bf722eb9 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Sun, 17 Nov 2024 12:36:31 +0100 Subject: [PATCH 18/45] Produce informative error on physical schema mismatch (#13434) Include details that can help understand the problem. --- datafusion/core/src/physical_planner.rs | 36 ++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 69c2ccc04aaa8..44537c951f945 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -659,7 +659,41 @@ impl DefaultPhysicalPlanner { if !options.execution.skip_physical_aggregate_schema_check && &physical_input_schema != physical_input_schema_from_logical { - return internal_err!("Physical input schema should be the same as the one converted from logical input schema."); + let mut differences = Vec::new(); + if physical_input_schema.fields().len() + != physical_input_schema_from_logical.fields().len() + { + differences.push(format!( + "Different number of fields: (physical) {} vs (logical) {}", + physical_input_schema.fields().len(), + physical_input_schema_from_logical.fields().len() + )); + } + for (i, (physical_field, logical_field)) in physical_input_schema + .fields() + .iter() + .zip(physical_input_schema_from_logical.fields()) + .enumerate() + { + if physical_field.name() != logical_field.name() { + differences.push(format!( + "field name at index {}: (physical) {} vs (logical) {}", + i, + physical_field.name(), + logical_field.name() + )); + } + if physical_field.data_type() != logical_field.data_type() { + differences.push(format!("field data type at index {} [{}]: (physical) {} vs (logical) {}", i, physical_field.name(), physical_field.data_type(), logical_field.data_type())); + } + if physical_field.is_nullable() != logical_field.is_nullable() { + differences.push(format!("field nullability at index {} [{}]: (physical) {} vs (logical) {}", i, physical_field.name(), physical_field.is_nullable(), logical_field.is_nullable())); + } + } + return internal_err!("Physical input schema should be the same as the one converted from logical input schema. Differences: {}", differences + .iter() + .map(|s| format!("\n\t- {}", s)) + .join("")); } let groups = self.create_grouping_physical_expr( From 6b0570bb8f48b0ccd605d0b5156cc98749c9c913 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Sun, 17 Nov 2024 12:37:09 +0100 Subject: [PATCH 19/45] Fix invalid swap for LeftMark nested loops join (#13426) --- .../src/physical_optimizer/join_selection.rs | 4 +- datafusion/sqllogictest/test_files/join.slt | 86 +++++++++++++++++++ 2 files changed, 89 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index fdb2920300dd4..511aaacf3ef10 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -494,7 +494,9 @@ fn statistical_join_selection_subrule( } else if let Some(nl_join) = plan.as_any().downcast_ref::() { let left = nl_join.left(); let right = nl_join.right(); - if should_swap_join_order(&**left, &**right)? { + if supports_swap(*nl_join.join_type()) + && should_swap_join_order(&**left, &**right)? + { swap_nl_join(nl_join).map(Some)? } else { None diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index 39f903a587143..1feacc5ebe53e 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -1226,3 +1226,89 @@ select t1.v1 from t1 join t1 using(v1) cross join (select struct('foo' as v1) as statement ok drop table t1; + + +statement ok +CREATE TABLE t1(a INTEGER, b INTEGER, c INTEGER, d INTEGER, e INTEGER); + +statement ok +INSERT INTO t1(e,c,b,d,a) VALUES(103,102,100,101,104); + +statement ok +INSERT INTO t1(a,c,d,e,b) VALUES(107,106,108,109,105); + +statement ok +INSERT INTO t1(d,c,e,a,b) VALUES(116,119,117,115,118); + +statement ok +INSERT INTO t1(c,d,b,e,a) VALUES(123,122,124,120,121); + +statement ok +INSERT INTO t1(b,a,e,d,c) VALUES(145,149,146,148,147); + +statement ok +INSERT INTO t1(b,c,a,d,e) VALUES(151,150,153,154,152); + +statement ok +INSERT INTO t1(c,b,a,d,e) VALUES(161,160,163,164,162); + +statement ok +INSERT INTO t1(b,d,a,e,c) VALUES(167,169,168,165,166); + +statement ok +INSERT INTO t1(d,b,c,e,a) VALUES(171,170,172,173,174); + +statement ok +INSERT INTO t1(e,c,a,d,b) VALUES(177,176,179,178,175); + +statement ok +INSERT INTO t1(b,e,a,d,c) VALUES(181,180,182,183,184); + +statement ok +INSERT INTO t1(c,e,a,b,d) VALUES(208,209,205,206,207); + +statement ok +INSERT INTO t1(c,e,a,d,b) VALUES(214,210,213,212,211); + +statement ok +INSERT INTO t1(b,c,a,d,e) VALUES(218,215,216,217,219); + +statement ok +INSERT INTO t1(e,c,b,a,d) VALUES(242,244,240,243,241); + +statement ok +INSERT INTO t1(e,d,c,b,a) VALUES(246,248,247,249,245); + +# Regression test for https://github.com/apache/datafusion/issues/13425 +query IIIIII +SELECT a+b*2, + a+b*2+c*3+d*4, + CASE WHEN ac OR e Date: Mon, 18 Nov 2024 13:32:24 +0100 Subject: [PATCH 20/45] Fix redundant data copying in unnest (#13441) * Fix redundant data copying in unnest * Add test * fix typo --- datafusion/physical-plan/src/unnest.rs | 62 ++++++++++++++++--- datafusion/sqllogictest/test_files/unnest.slt | 6 ++ 2 files changed, 59 insertions(+), 9 deletions(-) diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs index 06288a1f70419..0615e6738a1fc 100644 --- a/datafusion/physical-plan/src/unnest.rs +++ b/datafusion/physical-plan/src/unnest.rs @@ -36,7 +36,7 @@ use arrow::compute::kernels::zip::zip; use arrow::compute::{cast, is_not_null, kernels, sum}; use arrow::datatypes::{DataType, Int64Type, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use arrow_array::{Int64Array, Scalar, StructArray}; +use arrow_array::{new_null_array, Int64Array, Scalar, StructArray}; use arrow_ord::cmp::lt; use datafusion_common::{ exec_datafusion_err, exec_err, internal_err, HashMap, HashSet, Result, UnnestOptions, @@ -453,16 +453,36 @@ fn list_unnest_at_level( // Create the take indices array for other columns let take_indices = create_take_indicies(unnested_length, total_length); - - // Dimension of arrays in batch is untouched, but the values are repeated - // as the side effect of unnesting - let ret = repeat_arrs_from_indices(batch, &take_indices)?; unnested_temp_arrays .into_iter() .zip(list_unnest_specs.iter()) .for_each(|(flatten_arr, unnesting)| { temp_unnested_arrs.insert(*unnesting, flatten_arr); }); + + let repeat_mask: Vec = batch + .iter() + .enumerate() + .map(|(i, _)| { + // Check if the column is needed in future levels (levels below the current one) + let needed_in_future_levels = list_type_unnests.iter().any(|unnesting| { + unnesting.index_in_input_schema == i && unnesting.depth < level_to_unnest + }); + + // Check if the column is involved in unnesting at any level + let is_involved_in_unnesting = list_type_unnests + .iter() + .any(|unnesting| unnesting.index_in_input_schema == i); + + // Repeat columns needed in future levels or not unnested. + needed_in_future_levels || !is_involved_in_unnesting + }) + .collect(); + + // Dimension of arrays in batch is untouched, but the values are repeated + // as the side effect of unnesting + let ret = repeat_arrs_from_indices(batch, &take_indices, &repeat_mask)?; + Ok((ret, total_length)) } struct UnnestingResult { @@ -859,8 +879,11 @@ fn create_take_indicies( builder.finish() } -/// Create the batch given an arrays and a `indices` array -/// that is used by the take kernel to copy values. +/// Create a batch of arrays based on an input `batch` and a `indices` array. +/// The `indices` array is used by the take kernel to repeat values in the arrays +/// that are marked with `true` in the `repeat_mask`. Arrays marked with `false` +/// in the `repeat_mask` will be replaced with arrays filled with nulls of the +/// appropriate length. /// /// For example if we have the following batch: /// @@ -890,14 +913,35 @@ fn create_take_indicies( /// c2: 'a', 'b', 'c', 'c', 'c', null, 'd', 'd' /// ``` /// +/// The `repeat_mask` determines whether an array's values are repeated or replaced with nulls. +/// For example, if the `repeat_mask` is: +/// +/// ```ignore +/// [true, false] +/// ``` +/// +/// The final batch will look like: +/// +/// ```ignore +/// c1: 1, null, 2, 3, 4, null, 5, 6 // Repeated using `indices` +/// c2: null, null, null, null, null, null, null, null // Replaced with nulls +/// fn repeat_arrs_from_indices( batch: &[ArrayRef], indices: &PrimitiveArray, + repeat_mask: &[bool], ) -> Result>> { batch .iter() - .map(|arr| Ok(kernels::take::take(arr, indices, None)?)) - .collect::>() + .zip(repeat_mask.iter()) + .map(|(arr, &repeat)| { + if repeat { + Ok(kernels::take::take(arr, indices, None)?) + } else { + Ok(new_null_array(arr.data_type(), arr.len())) + } + }) + .collect() } #[cfg(test)] diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt index 8ebed5b25ca92..2e1b8b87cc429 100644 --- a/datafusion/sqllogictest/test_files/unnest.slt +++ b/datafusion/sqllogictest/test_files/unnest.slt @@ -853,3 +853,9 @@ select unnest(u.column5), j.* except(column2, column3) from unnest_table u join 1 2 1 3 4 2 NULL NULL 4 + +## Issue: https://github.com/apache/datafusion/issues/13237 +query I +select count(*) from (select unnest(range(0, 100000)) id) t inner join (select unnest(range(0, 100000)) id) t1 on t.id = t1.id; +---- +100000 From e03aa1271042d082906c4df46e1ccc32707e937d Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Mon, 18 Nov 2024 06:28:00 -0700 Subject: [PATCH 21/45] Add docs (#13454) --- datafusion/catalog/src/table.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/datafusion/catalog/src/table.rs b/datafusion/catalog/src/table.rs index ca3a2bef882e2..d771930de25de 100644 --- a/datafusion/catalog/src/table.rs +++ b/datafusion/catalog/src/table.rs @@ -247,6 +247,9 @@ pub trait TableProvider: Debug + Sync + Send { } /// Get statistics for this table, if available + /// Although not presently used in mainline DataFusion, this allows implementation specific + /// behavior for downstream repositories, in conjunction with specialized optimizer rules to + /// perform operations such as re-ordering of joins. fn statistics(&self) -> Option { None } From 22c1f54411a02009629b3a76a43bd4343add045d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 18 Nov 2024 11:02:27 -0500 Subject: [PATCH 22/45] Update sqllogictest requirement from 0.22.0 to 0.23.0 (#13464) Updates the requirements on [sqllogictest](https://github.com/risinglightdb/sqllogictest-rs) to permit the latest version. - [Release notes](https://github.com/risinglightdb/sqllogictest-rs/releases) - [Changelog](https://github.com/risinglightdb/sqllogictest-rs/blob/main/CHANGELOG.md) - [Commits](https://github.com/risinglightdb/sqllogictest-rs/compare/v0.22.0...v0.23.0) --- updated-dependencies: - dependency-name: sqllogictest dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- datafusion/sqllogictest/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml index 81682558d0a94..ed2b9c49715e3 100644 --- a/datafusion/sqllogictest/Cargo.toml +++ b/datafusion/sqllogictest/Cargo.toml @@ -51,7 +51,7 @@ object_store = { workspace = true } postgres-protocol = { version = "0.6.4", optional = true } postgres-types = { version = "0.2.4", optional = true } rust_decimal = { version = "1.27.0" } -sqllogictest = "0.22.0" +sqllogictest = "0.23.0" sqlparser = { workspace = true } tempfile = { workspace = true } thiserror = "2.0.0" From 498bcb9e76f7d15b197ca071b0bb28643d38d4b5 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 18 Nov 2024 15:32:55 -0500 Subject: [PATCH 23/45] Improve documentation (and ASCII art) about streaming execution, and thread pools (#13423) * Improve documentation about streaming and threadpools * Update datafusion/core/src/lib.rs Co-authored-by: Jonah Gao * Apply suggestions from code review Co-authored-by: Jonah Gao * Correct push/pull, add link to DuckDB, update paper link * Add note about spawn blocking --------- Co-authored-by: Jonah Gao --- datafusion/core/src/lib.rs | 189 ++++++++++++++++++++++++++++++++++--- 1 file changed, 175 insertions(+), 14 deletions(-) diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index b2df32a62e441..b58ef66d4cd2b 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -382,14 +382,14 @@ //! //! Calling [`execute`] produces 1 or more partitions of data, //! as a [`SendableRecordBatchStream`], which implements a pull based execution -//! API. Calling `.next().await` will incrementally compute and return the next +//! API. Calling [`next()`]`.await` will incrementally compute and return the next //! [`RecordBatch`]. Balanced parallelism is achieved using [Volcano style] //! "Exchange" operations implemented by [`RepartitionExec`]. //! //! While some recent research such as [Morsel-Driven Parallelism] describes challenges //! with the pull style Volcano execution model on NUMA architectures, in practice DataFusion achieves -//! similar scalability as systems that use morsel driven approach such as DuckDB. -//! See the [DataFusion paper submitted to SIGMOD] for more details. +//! similar scalability as systems that use push driven schedulers [such as DuckDB]. +//! See the [DataFusion paper in SIGMOD 2024] for more details. //! //! [`execute`]: physical_plan::ExecutionPlan::execute //! [`SendableRecordBatchStream`]: crate::physical_plan::SendableRecordBatchStream @@ -403,22 +403,183 @@ //! [`RepartitionExec`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/repartition/struct.RepartitionExec.html //! [Volcano style]: https://w6113.github.io/files/papers/volcanoparallelism-89.pdf //! [Morsel-Driven Parallelism]: https://db.in.tum.de/~leis/papers/morsels.pdf -//! [DataFusion paper submitted SIGMOD]: https://github.com/apache/datafusion/files/13874720/DataFusion_Query_Engine___SIGMOD_2024.pdf +//! [DataFusion paper in SIGMOD 2024]: https://github.com/apache/datafusion/files/15149988/DataFusion_Query_Engine___SIGMOD_2024-FINAL-mk4.pdf +//! [such as DuckDB]: https://github.com/duckdb/duckdb/issues/1583 //! [implementors of `ExecutionPlan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html#implementors //! -//! ## Thread Scheduling +//! ## Streaming Execution //! -//! DataFusion incrementally computes output from a [`SendableRecordBatchStream`] -//! with `target_partitions` threads. Parallelism is implementing using multiple -//! [Tokio] [`task`]s, which are executed by threads managed by a tokio Runtime. -//! While tokio is most commonly used -//! for asynchronous network I/O, its combination of an efficient, work-stealing -//! scheduler, first class compiler support for automatic continuation generation, -//! and exceptional performance makes it a compelling choice for CPU intensive -//! applications as well. This is explained in more detail in [Using Rustlang’s Async Tokio -//! Runtime for CPU-Bound Tasks]. +//! DataFusion is a "streaming" query engine which means `ExecutionPlan`s incrementally +//! read from their input(s) and compute output one [`RecordBatch`] at a time +//! by continually polling [`SendableRecordBatchStream`]s. Output and +//! intermediate `RecordBatch`s each have approximately `batch_size` rows, +//! which amortizes per-batch overhead of execution. +//! +//! Note that certain operations, sometimes called "pipeline breakers", +//! (for example full sorts or hash aggregations) are fundamentally non streaming and +//! must read their input fully before producing **any** output. As much as possible, +//! other operators read a single [`RecordBatch`] from their input to produce a +//! single `RecordBatch` as output. +//! +//! For example, given this SQL query: +//! +//! ```sql +//! SELECT date_trunc('month', time) FROM data WHERE id IN (10,20,30); +//! ``` +//! +//! The diagram below shows the call sequence when a consumer calls [`next()`] to +//! get the next `RecordBatch` of output. While it is possible that some +//! steps run on different threads, typically tokio will use the same thread +//! that called `next()` to read from the input, apply the filter, and +//! return the results without interleaving any other operations. This results +//! in excellent cache locality as the same CPU core that produces the data often +//! consumes it immediately as well. +//! +//! ```text +//! +//! Step 3: FilterExec calls next() Step 2: ProjectionExec calls +//! on input Stream next() on input Stream +//! ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +//! │ Step 1: Consumer +//! ▼ ▼ │ calls next() +//! ┏━━━━━━━━━━━━━━┓ ┏━━━━━┻━━━━━━━━━━━━━┓ ┏━━━━━━━━━━━━━━━━━━━━━━━━┓ +//! ┃ ┃ ┃ ┃ ┃ ◀ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +//! ┃ DataSource ┃ ┃ ┃ ┃ ┃ +//! ┃ (e.g. ┃ ┃ FilterExec ┃ ┃ ProjectionExec ┃ +//! ┃ ParquetExec) ┃ ┃id IN (10, 20, 30) ┃ ┃date_bin('month', time) ┃ +//! ┃ ┃ ┃ ┃ ┃ ┣ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ▶ +//! ┃ ┃ ┃ ┃ ┃ ┃ +//! ┗━━━━━━━━━━━━━━┛ ┗━━━━━━━━━━━┳━━━━━━━┛ ┗━━━━━━━━━━━━━━━━━━━━━━━━┛ +//! │ ▲ ▲ Step 6: ProjectionExec +//! │ │ │ computes date_trunc into a +//! └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ new RecordBatch returned +//! ┌─────────────────────┐ ┌─────────────┐ from client +//! │ RecordBatch │ │ RecordBatch │ +//! └─────────────────────┘ └─────────────┘ +//! +//! Step 4: DataSource returns a Step 5: FilterExec returns a new +//! single RecordBatch RecordBatch with only matching rows +//! ``` +//! +//! [`next()`]: futures::StreamExt::next +//! +//! ## Thread Scheduling, CPU / IO Thread Pools, and [Tokio] [`Runtime`]s +//! +//! DataFusion automatically runs each plan with multiple CPU cores using +//! a [Tokio] [`Runtime`] as a thread pool. While tokio is most commonly used +//! for asynchronous network I/O, the combination of an efficient, work-stealing +//! scheduler and first class compiler support for automatic continuation +//! generation (`async`), also makes it a compelling choice for CPU intensive +//! applications as explained in the [Using Rustlang’s Async Tokio +//! Runtime for CPU-Bound Tasks] blog. +//! +//! The number of cores used is determined by the `target_partitions` +//! configuration setting, which defaults to the number of CPU cores. +//! During execution, DataFusion creates this many distinct `async` [`Stream`]s and +//! this many distinct [Tokio] [`task`]s, which drive the `Stream`s +//! using threads managed by the `Runtime`. Many DataFusion `Stream`s perform +//! CPU intensive processing. +//! +//! Using `async` for CPU intensive tasks makes it easy for [`TableProvider`]s +//! to perform network I/O using standard Rust `async` during execution. +//! However, this design also makes it very easy to mix CPU intensive and latency +//! sensitive I/O work on the same thread pool ([`Runtime`]). +//! Using the same (default) `Runtime` is convenient, and often works well for +//! initial development and processing local files, but it can lead to problems +//! under load and/or when reading from network sources such as AWS S3. +//! +//! If your system does not fully utilize either the CPU or network bandwidth +//! during execution, or you see significantly higher tail (e.g. p99) latencies +//! responding to network requests, **it is likely you need to use a different +//! `Runtime` for CPU intensive DataFusion plans**. This effect can be especially +//! pronounced when running several queries concurrently. +//! +//! As shown in the following figure, using the same `Runtime` for both CPU +//! intensive processing and network requests can introduce significant +//! delays in responding to those network requests. Delays in processing network +//! requests can and does lead network flow control to throttle the available +//! bandwidth in response. +//! +//! ```text +//! Legend +//! +//! ┏━━━━━━┓ +//! Processing network request ┃ ┃ CPU bound work +//! is delayed due to processing ┗━━━━━━┛ +//! CPU bound work ┌─┐ +//! │ │ Network request +//! ││ └─┘ processing +//! +//! ││ +//! ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +//! │ │ +//! +//! ▼ ▼ +//! ┌─────────────┐ ┌─┐┌─┐┏━━━━━━━━━━━━━━━━━━━┓┏━━━━━━━━━━━━━━━━━━━┓┌─┐ +//! │ │thread 1 │ ││ │┃ Decoding ┃┃ Filtering ┃│ │ +//! │ │ └─┘└─┘┗━━━━━━━━━━━━━━━━━━━┛┗━━━━━━━━━━━━━━━━━━━┛└─┘ +//! │ │ ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓ +//! │Tokio Runtime│thread 2 ┃ Decoding ┃ Filtering ┃ Decoding ┃ ... +//! │(thread pool)│ ┗━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━┛ +//! │ │ ... ... +//! │ │ ┏━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓┌─┐ ┏━━━━━━━━━━━━━━┓ +//! │ │thread N ┃ Decoding ┃ Filtering ┃│ │ ┃ Decoding ┃ +//! └─────────────┘ ┗━━━━━━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━━━━━━┛└─┘ ┗━━━━━━━━━━━━━━┛ +//! ─────────────────────────────────────────────────────────────▶ +//! time +//! ``` +//! +//! The bottleneck resulting from network throttling can be avoided +//! by using separate [`Runtime`]s for the different types of work, as shown +//! in the diagram below. +//! +//! ```text +//! A separate thread pool processes network Legend +//! requests, reducing the latency for +//! processing each request ┏━━━━━━┓ +//! ┃ ┃ CPU bound work +//! │ ┗━━━━━━┛ +//! │ ┌─┐ +//! ┌ ─ ─ ─ ─ ┘ │ │ Network request +//! ┌ ─ ─ ─ ┘ └─┘ processing +//! │ +//! ▼ ▼ +//! ┌─────────────┐ ┌─┐┌─┐┌─┐ +//! │ │thread 1 │ ││ ││ │ +//! │ │ └─┘└─┘└─┘ +//! │Tokio Runtime│ ... +//! │(thread pool)│thread 2 +//! │ │ +//! │"IO Runtime" │ ... +//! │ │ ┌─┐ +//! │ │thread N │ │ +//! └─────────────┘ └─┘ +//! ─────────────────────────────────────────────────────────────▶ +//! time +//! +//! ┌─────────────┐ ┏━━━━━━━━━━━━━━━━━━━┓┏━━━━━━━━━━━━━━━━━━━┓ +//! │ │thread 1 ┃ Decoding ┃┃ Filtering ┃ +//! │ │ ┗━━━━━━━━━━━━━━━━━━━┛┗━━━━━━━━━━━━━━━━━━━┛ +//! │Tokio Runtime│ ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓ +//! │(thread pool)│thread 2 ┃ Decoding ┃ Filtering ┃ Decoding ┃ ... +//! │ │ ┗━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━┛ +//! │ CPU Runtime │ ... ... +//! │ │ ┏━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓ +//! │ │thread N ┃ Decoding ┃ Filtering ┃ Decoding ┃ +//! └─────────────┘ ┗━━━━━━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━┛ +//! ─────────────────────────────────────────────────────────────▶ +//! time +//!``` +//! +//! Note that DataFusion does not use [`tokio::task::spawn_blocking`] for +//! CPU-bounded work, because `spawn_blocking` is designed for blocking **IO**, +//! not designed CPU bound tasks. Among other challenges, spawned blocking +//! tasks can't yield waiting for input (can't call `await`) so they +//! can't be used to limit the number of concurrent CPU bound tasks or +//! keep the processing pipeline to the same core. //! //! [Tokio]: https://tokio.rs +//! [`Runtime`]: tokio::runtime::Runtime //! [`task`]: tokio::task //! [Using Rustlang’s Async Tokio Runtime for CPU-Bound Tasks]: https://thenewstack.io/using-rustlangs-async-tokio-runtime-for-cpu-bound-tasks/ //! From 900552cbdac3ee242178e4dce41339fbabcf1384 Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Mon, 18 Nov 2024 12:34:27 -0800 Subject: [PATCH 24/45] Fix Binary & Binary View Unparsing (#13427) * Skip casting to binary when inner expr is value (#60) * Skip casting to binary when inner expr is value * Update datafusion/sql/src/unparser/expr.rs Co-authored-by: Jack Eadie --------- Co-authored-by: Jack Eadie * Fix binary view cast (#63) * fix * Fix clippy error --------- Co-authored-by: Jack Eadie --- datafusion/sql/src/unparser/expr.rs | 78 ++++++++++++++++++++++------- 1 file changed, 59 insertions(+), 19 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index a6f7c4fd1100f..3bf4ae304721c 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -180,25 +180,7 @@ impl Unparser<'_> { }) } Expr::Cast(Cast { expr, data_type }) => { - let inner_expr = self.expr_to_sql_inner(expr)?; - match data_type { - DataType::Dictionary(_, _) => match inner_expr { - // Dictionary values don't need to be cast to other types when rewritten back to sql - ast::Expr::Value(_) => Ok(inner_expr), - _ => Ok(ast::Expr::Cast { - kind: ast::CastKind::Cast, - expr: Box::new(inner_expr), - data_type: self.arrow_dtype_to_ast_dtype(data_type)?, - format: None, - }), - }, - _ => Ok(ast::Expr::Cast { - kind: ast::CastKind::Cast, - expr: Box::new(inner_expr), - data_type: self.arrow_dtype_to_ast_dtype(data_type)?, - format: None, - }), - } + Ok(self.cast_to_sql(expr, data_type)?) } Expr::Literal(value) => Ok(self.scalar_to_sql(value)?), Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql_inner(expr), @@ -901,6 +883,31 @@ impl Unparser<'_> { }) } + // Explicit type cast on ast::Expr::Value is not needed by underlying engine for certain types + // For example: CAST(Utf8("binary_value") AS Binary) and CAST(Utf8("dictionary_value") AS Dictionary) + fn cast_to_sql(&self, expr: &Expr, data_type: &DataType) -> Result { + let inner_expr = self.expr_to_sql_inner(expr)?; + match inner_expr { + ast::Expr::Value(_) => match data_type { + DataType::Dictionary(_, _) | DataType::Binary | DataType::BinaryView => { + Ok(inner_expr) + } + _ => Ok(ast::Expr::Cast { + kind: ast::CastKind::Cast, + expr: Box::new(inner_expr), + data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + format: None, + }), + }, + _ => Ok(ast::Expr::Cast { + kind: ast::CastKind::Cast, + expr: Box::new(inner_expr), + data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + format: None, + }), + } + } + /// DataFusion ScalarValues sometimes require a ast::Expr to construct. /// For example ScalarValue::Date32(d) corresponds to the ast::Expr CAST('datestr' as DATE) fn scalar_to_sql(&self, v: &ScalarValue) -> Result { @@ -2237,6 +2244,39 @@ mod tests { } } + #[test] + fn test_cast_value_to_binary_expr() { + let tests = [ + ( + Expr::Cast(Cast { + expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( + "blah".to_string(), + )))), + data_type: DataType::Binary, + }), + "'blah'", + ), + ( + Expr::Cast(Cast { + expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( + "blah".to_string(), + )))), + data_type: DataType::BinaryView, + }), + "'blah'", + ), + ]; + for (value, expected) in tests { + let dialect = CustomDialectBuilder::new().build(); + let unparser = Unparser::new(&dialect); + + let ast = unparser.expr_to_sql(&value).expect("to be unparsed"); + let actual = format!("{ast}"); + + assert_eq!(actual, expected); + } + } + #[test] fn custom_dialect_use_char_for_utf8_cast() -> Result<()> { let default_dialect = CustomDialectBuilder::default().build(); From adcf90f18c55e560c40215c8856999eb6cab9e1e Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Tue, 19 Nov 2024 06:06:29 +0900 Subject: [PATCH 25/45] Support Utf8View in Unparser `expr_to_sql` (#13462) * Support Utf8View in Unparser expr_to_sql * Add another test * Update expr.rs Co-authored-by: Sherin Jacob * Fix import * feedback * Add null/is_not_null test --------- Co-authored-by: Sherin Jacob --- datafusion/sql/src/unparser/expr.rs | 52 ++++++++++++++++++++++++++--- 1 file changed, 48 insertions(+), 4 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 3bf4ae304721c..f1f28258f9bd6 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1458,9 +1458,7 @@ impl Unparser<'_> { } DataType::Utf8 => Ok(self.dialect.utf8_cast_dtype()), DataType::LargeUtf8 => Ok(self.dialect.large_utf8_cast_dtype()), - DataType::Utf8View => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") - } + DataType::Utf8View => Ok(self.dialect.utf8_cast_dtype()), DataType::List(_) => { not_impl_err!("Unsupported DataType: conversion: {data_type:?}") } @@ -1520,7 +1518,7 @@ mod tests { use datafusion_common::TableReference; use datafusion_expr::expr::WildcardOptions; use datafusion_expr::{ - case, col, cube, exists, grouping_set, interval_datetime_lit, + case, cast, col, cube, exists, grouping_set, interval_datetime_lit, interval_year_month_lit, lit, not, not_exists, out_ref_col, placeholder, rollup, table_scan, try_cast, when, wildcard, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility, WindowFrame, WindowFunctionDefinition, @@ -2540,4 +2538,50 @@ mod tests { } Ok(()) } + + #[test] + fn test_utf8_view_to_sql() -> Result<()> { + let dialect = CustomDialectBuilder::new() + .with_utf8_cast_dtype(ast::DataType::Char(None)) + .build(); + let unparser = Unparser::new(&dialect); + + let ast_dtype = unparser.arrow_dtype_to_ast_dtype(&DataType::Utf8View)?; + + assert_eq!(ast_dtype, ast::DataType::Char(None)); + + let expr = cast(col("a"), DataType::Utf8View); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + let expected = r#"CAST(a AS CHAR)"#.to_string(); + + assert_eq!(actual, expected); + + let expr = col("a").eq(lit(ScalarValue::Utf8View(Some("hello".to_string())))); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + let expected = r#"(a = 'hello')"#.to_string(); + + assert_eq!(actual, expected); + + let expr = col("a").is_not_null(); + + let ast = unparser.expr_to_sql(&expr)?; + let actual = format!("{}", ast); + let expected = r#"a IS NOT NULL"#.to_string(); + + assert_eq!(actual, expected); + + let expr = col("a").is_null(); + + let ast = unparser.expr_to_sql(&expr)?; + let actual = format!("{}", ast); + let expected = r#"a IS NULL"#.to_string(); + + assert_eq!(actual, expected); + + Ok(()) + } } From 1a09adf9d093405dd9465f98efaf626198f96d16 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Tue, 19 Nov 2024 10:59:29 +0900 Subject: [PATCH 26/45] Unparse inner join with no conditions as a cross join (#13460) * Unparse inner join with no conditions as a cross join * Add explicit cross join * Fix mysql test --- datafusion/sql/src/unparser/plan.rs | 11 ++++++++++- datafusion/sql/tests/cases/plan_to_sql.rs | 16 +++++++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 433c456855a30..81e47ed939f22 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -876,7 +876,16 @@ impl Unparser<'_> { constraint: ast::JoinConstraint, ) -> Result { Ok(match join_type { - JoinType::Inner => ast::JoinOperator::Inner(constraint), + JoinType::Inner => match &constraint { + ast::JoinConstraint::On(_) + | ast::JoinConstraint::Using(_) + | ast::JoinConstraint::Natural => ast::JoinOperator::Inner(constraint), + ast::JoinConstraint::None => { + // Inner joins with no conditions or filters are not valid SQL in most systems, + // return a CROSS JOIN instead + ast::JoinOperator::CrossJoin + } + }, JoinType::Left => ast::JoinOperator::LeftOuter(constraint), JoinType::Right => ast::JoinOperator::RightOuter(constraint), JoinType::Full => ast::JoinOperator::FullOuter(constraint), diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 4f43d7333dd14..f9d97cdc74af9 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -296,7 +296,7 @@ fn roundtrip_statement_with_dialect() -> Result<()> { TestStatementWithDialect { sql: "select min(ta.j1_id) as j1_min, max(tb.j1_max) from j1 ta, (select distinct max(ta.j1_id) as j1_max from j1 ta order by max(ta.j1_id)) tb order by min(ta.j1_id) limit 10;", expected: - "SELECT `j1_min`, `max(tb.j1_max)` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, max(`tb`.`j1_max`), min(`ta`.`j1_id`) FROM `j1` AS `ta` JOIN (SELECT `j1_max` FROM (SELECT DISTINCT max(`ta`.`j1_id`) AS `j1_max` FROM `j1` AS `ta`) AS `derived_distinct`) AS `tb` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10", + "SELECT `j1_min`, `max(tb.j1_max)` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, max(`tb`.`j1_max`), min(`ta`.`j1_id`) FROM `j1` AS `ta` CROSS JOIN (SELECT `j1_max` FROM (SELECT DISTINCT max(`ta`.`j1_id`) AS `j1_max` FROM `j1` AS `ta`) AS `derived_distinct`) AS `tb` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10", parser_dialect: Box::new(MySqlDialect {}), unparser_dialect: Box::new(UnparserMySqlDialect {}), }, @@ -1253,3 +1253,17 @@ fn test_unnest_to_sql() { r#"SELECT UNNEST([1, 2, 2, 5, NULL]) AS u1"#, ); } + +#[test] +fn test_join_with_no_conditions() { + sql_round_trip( + GenericDialect {}, + "SELECT * FROM j1 JOIN j2", + "SELECT * FROM j1 CROSS JOIN j2", + ); + sql_round_trip( + GenericDialect {}, + "SELECT * FROM j1 CROSS JOIN j2", + "SELECT * FROM j1 CROSS JOIN j2", + ); +} From f3023835ebb913208816568bedef225e45a16d0b Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Tue, 19 Nov 2024 11:35:14 +0800 Subject: [PATCH 27/45] Remove unreachable filter logic in final grouping stage (#13463) * rm filtere in final grouping stage Signed-off-by: jayzhan211 * add comment Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/functions-aggregate/src/count.rs | 24 ++++------- .../functions-aggregate/src/variance.rs | 42 ++++++------------- .../physical-plan/src/aggregates/row_hash.rs | 11 +++-- 3 files changed, 24 insertions(+), 53 deletions(-) diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 52181372698f2..8fdd702b5b7c6 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -467,7 +467,8 @@ impl GroupsAccumulator for CountGroupsAccumulator { &mut self, values: &[ArrayRef], group_indices: &[usize], - opt_filter: Option<&BooleanArray>, + // Since aggregate filter should be applied in partial stage, in final stage there should be no filter + _opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 1, "one argument to merge_batch"); @@ -480,22 +481,11 @@ impl GroupsAccumulator for CountGroupsAccumulator { // Adds the counts with the partial counts self.counts.resize(total_num_groups, 0); - match opt_filter { - Some(filter) => filter - .iter() - .zip(group_indices.iter()) - .zip(partial_counts.iter()) - .for_each(|((filter_value, &group_index), partial_count)| { - if let Some(true) = filter_value { - self.counts[group_index] += partial_count; - } - }), - None => group_indices.iter().zip(partial_counts.iter()).for_each( - |(&group_index, partial_count)| { - self.counts[group_index] += partial_count; - }, - ), - } + group_indices.iter().zip(partial_counts.iter()).for_each( + |(&group_index, partial_count)| { + self.counts[group_index] += partial_count; + }, + ); Ok(()) } diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 8daa85a5cc834..55d4181a96dfe 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -460,7 +460,7 @@ impl VarianceGroupsAccumulator { counts: &UInt64Array, means: &Float64Array, m2s: &Float64Array, - opt_filter: Option<&BooleanArray>, + _opt_filter: Option<&BooleanArray>, mut value_fn: F, ) where F: FnMut(usize, u64, f64, f64) + Send, @@ -469,33 +469,14 @@ impl VarianceGroupsAccumulator { assert_eq!(means.null_count(), 0); assert_eq!(m2s.null_count(), 0); - match opt_filter { - None => { - group_indices - .iter() - .zip(counts.values().iter()) - .zip(means.values().iter()) - .zip(m2s.values().iter()) - .for_each(|(((&group_index, &count), &mean), &m2)| { - value_fn(group_index, count, mean, m2); - }); - } - Some(filter) => { - group_indices - .iter() - .zip(counts.values().iter()) - .zip(means.values().iter()) - .zip(m2s.values().iter()) - .zip(filter.iter()) - .for_each( - |((((&group_index, &count), &mean), &m2), filter_value)| { - if let Some(true) = filter_value { - value_fn(group_index, count, mean, m2); - } - }, - ); - } - } + group_indices + .iter() + .zip(counts.values().iter()) + .zip(means.values().iter()) + .zip(m2s.values().iter()) + .for_each(|(((&group_index, &count), &mean), &m2)| { + value_fn(group_index, count, mean, m2); + }); } pub fn variance( @@ -554,7 +535,8 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { &mut self, values: &[ArrayRef], group_indices: &[usize], - opt_filter: Option<&BooleanArray>, + // Since aggregate filter should be applied in partial stage, in final stage there should be no filter + _opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 3, "two arguments to merge_batch"); @@ -569,7 +551,7 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { partial_counts, partial_means, partial_m2s, - opt_filter, + None, |group_index, partial_count, partial_mean, partial_m2| { if partial_count == 0 { return; diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 0fa9f206f13db..965adbb8c7804 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -859,14 +859,13 @@ impl GroupedHashAggregateStream { )?; } _ => { + if opt_filter.is_some() { + return internal_err!("aggregate filter should be applied in partial stage, there should be no filter in final stage"); + } + // if aggregation is over intermediate states, // use merge - acc.merge_batch( - values, - group_indices, - opt_filter, - total_num_groups, - )?; + acc.merge_batch(values, group_indices, None, total_num_groups)?; } } } From c44b61320b68f952dca101769c450f912e466f5e Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Tue, 19 Nov 2024 19:50:09 +0800 Subject: [PATCH 28/45] MINOR: remove one duplicated inparam in TopK (#13479) Signed-off-by: Ruihang Xia --- datafusion/physical-plan/src/sorts/sort.rs | 1 - datafusion/physical-plan/src/topk/mod.rs | 7 ++----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index e9f17ddebabc9..2e97334493dd6 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -932,7 +932,6 @@ impl ExecutionPlan for SortExec { context.session_config().batch_size(), context.runtime_env(), &self.metrics_set, - partition, )?; Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema(), diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 27bb3b2b36b97..0f722ec143ff7 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -95,9 +95,7 @@ pub struct TopK { impl TopK { /// Create a new [`TopK`] that stores the top `k` values, as /// defined by the sort expressions in `expr`. - // TODO: make a builder or some other nicer API to avoid the - // clippy warning - #[allow(clippy::too_many_arguments)] + // TODO: make a builder or some other nicer API pub fn try_new( partition_id: usize, schema: SchemaRef, @@ -106,7 +104,6 @@ impl TopK { batch_size: usize, runtime: Arc, metrics: &ExecutionPlanMetricsSet, - partition: usize, ) -> Result { let reservation = MemoryConsumer::new(format!("TopK[{partition_id}]")) .register(&runtime.memory_pool); @@ -133,7 +130,7 @@ impl TopK { Ok(Self { schema: Arc::clone(&schema), - metrics: TopKMetrics::new(metrics, partition), + metrics: TopKMetrics::new(metrics, partition_id), reservation, batch_size, expr, From 9fb5ff99350cea8d360a0519ad9abb8046770973 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Tue, 19 Nov 2024 13:01:58 +0100 Subject: [PATCH 29/45] Fix join on arrays of unhashable types and allow hash join on all types supported at run-time (#13388) * Remove unused code paths from create_hashes The `downcast_primitive_array!` macro handles all primitive types and only then delegates to fallbacks. It handles Decimal128 and Decimal256 internally. * Fix join on arrays of unhashable types and allow hash join on all types supported at run-time #13388 Update can_hash to match currently supported hashes. * Rename table_with_many_types in tests * Test join on binary is hash join --- datafusion/common/src/hash_utils.rs | 10 +--- datafusion/expr/src/utils.rs | 59 +++++++++++++------ datafusion/sqllogictest/src/test_context.rs | 9 ++- .../test_files/information_schema_columns.slt | 16 ++--- datafusion/sqllogictest/test_files/joins.slt | 21 +++++++ 5 files changed, 79 insertions(+), 36 deletions(-) diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index 8bd646626e068..e18d70844d32b 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -32,7 +32,7 @@ use arrow_buffer::IntervalMonthDayNano; use crate::cast::{ as_binary_view_array, as_boolean_array, as_fixed_size_list_array, as_generic_binary_array, as_large_list_array, as_list_array, as_map_array, - as_primitive_array, as_string_array, as_string_view_array, as_struct_array, + as_string_array, as_string_view_array, as_struct_array, }; use crate::error::Result; #[cfg(not(feature = "force_hash_collisions"))] @@ -392,14 +392,6 @@ pub fn create_hashes<'a>( let array: &FixedSizeBinaryArray = array.as_any().downcast_ref().unwrap(); hash_array(array, random_state, hashes_buffer, rehash) } - DataType::Decimal128(_, _) => { - let array = as_primitive_array::(array)?; - hash_array_primitive(array, random_state, hashes_buffer, rehash) - } - DataType::Decimal256(_, _) => { - let array = as_primitive_array::(array)?; - hash_array_primitive(array, random_state, hashes_buffer, rehash) - } DataType::Dictionary(_, _) => downcast_dictionary_array! { array => hash_dictionary(array, random_state, hashes_buffer, rehash)?, _ => unreachable!() diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index c22ee244fe286..6f7c5d379260e 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -29,7 +29,7 @@ use crate::{ }; use datafusion_expr_common::signature::{Signature, TypeSignature}; -use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; @@ -958,7 +958,7 @@ pub(crate) fn find_column_indexes_referenced_by_expr( /// Can this data type be used in hash join equal conditions?? /// Data types here come from function 'equal_rows', if more data types are supported -/// in equal_rows(hash join), add those data types here to generate join logical plan. +/// in create_hashes, add those data types here to generate join logical plan. pub fn can_hash(data_type: &DataType) -> bool { match data_type { DataType::Null => true, @@ -971,31 +971,38 @@ pub fn can_hash(data_type: &DataType) -> bool { DataType::UInt16 => true, DataType::UInt32 => true, DataType::UInt64 => true, + DataType::Float16 => true, DataType::Float32 => true, DataType::Float64 => true, - DataType::Timestamp(time_unit, _) => match time_unit { - TimeUnit::Second => true, - TimeUnit::Millisecond => true, - TimeUnit::Microsecond => true, - TimeUnit::Nanosecond => true, - }, + DataType::Decimal128(_, _) => true, + DataType::Decimal256(_, _) => true, + DataType::Timestamp(_, _) => true, DataType::Utf8 => true, DataType::LargeUtf8 => true, DataType::Utf8View => true, - DataType::Decimal128(_, _) => true, + DataType::Binary => true, + DataType::LargeBinary => true, + DataType::BinaryView => true, DataType::Date32 => true, DataType::Date64 => true, + DataType::Time32(_) => true, + DataType::Time64(_) => true, + DataType::Duration(_) => true, + DataType::Interval(_) => true, DataType::FixedSizeBinary(_) => true, - DataType::Dictionary(key_type, value_type) - if *value_type.as_ref() == DataType::Utf8 => - { - DataType::is_dictionary_key_type(key_type) + DataType::Dictionary(key_type, value_type) => { + DataType::is_dictionary_key_type(key_type) && can_hash(value_type) } - DataType::List(_) => true, - DataType::LargeList(_) => true, - DataType::FixedSizeList(_, _) => true, + DataType::List(value_type) => can_hash(value_type.data_type()), + DataType::LargeList(value_type) => can_hash(value_type.data_type()), + DataType::FixedSizeList(value_type, _) => can_hash(value_type.data_type()), + DataType::Map(map_struct, true | false) => can_hash(map_struct.data_type()), DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())), - _ => false, + + DataType::ListView(_) + | DataType::LargeListView(_) + | DataType::Union(_, _) + | DataType::RunEndEncoded(_, _) => false, } } @@ -1403,6 +1410,7 @@ mod tests { test::function_stub::max_udaf, test::function_stub::min_udaf, test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFunctionDefinition, }; + use arrow::datatypes::{UnionFields, UnionMode}; #[test] fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> { @@ -1805,4 +1813,21 @@ mod tests { assert!(accum.contains(&Column::from_name("a"))); Ok(()) } + + #[test] + fn test_can_hash() { + let union_fields: UnionFields = [ + (0, Arc::new(Field::new("A", DataType::Int32, true))), + (1, Arc::new(Field::new("B", DataType::Float64, true))), + ] + .into_iter() + .collect(); + + let union_type = DataType::Union(union_fields, UnionMode::Sparse); + assert!(!can_hash(&union_type)); + + let list_union_type = + DataType::List(Arc::new(Field::new("my_union", union_type, true))); + assert!(!can_hash(&list_union_type)); + } } diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index 477f225443e28..2466303c32a97 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -106,6 +106,8 @@ impl TestContext { let example_udf = create_example_udf(); test_ctx.ctx.register_udf(example_udf); register_partition_table(&mut test_ctx).await; + info!("Registering table with many types"); + register_table_with_many_types(test_ctx.session_ctx()).await; } "metadata.slt" => { info!("Registering metadata table tables"); @@ -251,8 +253,11 @@ pub async fn register_table_with_many_types(ctx: &SessionContext) { .unwrap(); ctx.register_catalog("my_catalog", Arc::new(catalog)); - ctx.register_table("my_catalog.my_schema.t2", table_with_many_types()) - .unwrap(); + ctx.register_table( + "my_catalog.my_schema.table_with_many_types", + table_with_many_types(), + ) + .unwrap(); } pub async fn register_table_with_map(ctx: &SessionContext) { diff --git a/datafusion/sqllogictest/test_files/information_schema_columns.slt b/datafusion/sqllogictest/test_files/information_schema_columns.slt index 7cf845c16d738..d348a764fa85f 100644 --- a/datafusion/sqllogictest/test_files/information_schema_columns.slt +++ b/datafusion/sqllogictest/test_files/information_schema_columns.slt @@ -37,17 +37,17 @@ query TTTTITTTIIIIIIT rowsort SELECT * from information_schema.columns; ---- my_catalog my_schema t1 i 0 NULL YES Int32 NULL NULL 32 2 NULL NULL NULL -my_catalog my_schema t2 binary_col 4 NULL NO Binary NULL 2147483647 NULL NULL NULL NULL NULL -my_catalog my_schema t2 float64_col 1 NULL YES Float64 NULL NULL 24 2 NULL NULL NULL -my_catalog my_schema t2 int32_col 0 NULL NO Int32 NULL NULL 32 2 NULL NULL NULL -my_catalog my_schema t2 large_binary_col 5 NULL NO LargeBinary NULL 9223372036854775807 NULL NULL NULL NULL NULL -my_catalog my_schema t2 large_utf8_col 3 NULL NO LargeUtf8 NULL 9223372036854775807 NULL NULL NULL NULL NULL -my_catalog my_schema t2 timestamp_nanos 6 NULL NO Timestamp(Nanosecond, None) NULL NULL NULL NULL NULL NULL NULL -my_catalog my_schema t2 utf8_col 2 NULL YES Utf8 NULL 2147483647 NULL NULL NULL NULL NULL +my_catalog my_schema table_with_many_types binary_col 4 NULL NO Binary NULL 2147483647 NULL NULL NULL NULL NULL +my_catalog my_schema table_with_many_types float64_col 1 NULL YES Float64 NULL NULL 24 2 NULL NULL NULL +my_catalog my_schema table_with_many_types int32_col 0 NULL NO Int32 NULL NULL 32 2 NULL NULL NULL +my_catalog my_schema table_with_many_types large_binary_col 5 NULL NO LargeBinary NULL 9223372036854775807 NULL NULL NULL NULL NULL +my_catalog my_schema table_with_many_types large_utf8_col 3 NULL NO LargeUtf8 NULL 9223372036854775807 NULL NULL NULL NULL NULL +my_catalog my_schema table_with_many_types timestamp_nanos 6 NULL NO Timestamp(Nanosecond, None) NULL NULL NULL NULL NULL NULL NULL +my_catalog my_schema table_with_many_types utf8_col 2 NULL YES Utf8 NULL 2147483647 NULL NULL NULL NULL NULL # Cleanup statement ok drop table t1 statement ok -drop table t2 +drop table table_with_many_types diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index d45dbc7ee1ae5..e636e93007a4a 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -4292,3 +4292,24 @@ query T select * from table1 as t1 natural join table1_stringview as t2; ---- foo + +query TT +EXPLAIN SELECT count(*) +FROM my_catalog.my_schema.table_with_many_types AS l +JOIN my_catalog.my_schema.table_with_many_types AS r ON l.binary_col = r.binary_col +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] +02)--Projection: +03)----Inner Join: l.binary_col = r.binary_col +04)------SubqueryAlias: l +05)--------TableScan: my_catalog.my_schema.table_with_many_types projection=[binary_col] +06)------SubqueryAlias: r +07)--------TableScan: my_catalog.my_schema.table_with_many_types projection=[binary_col] +physical_plan +01)AggregateExec: mode=Single, gby=[], aggr=[count(*)] +02)--ProjectionExec: expr=[] +03)----CoalesceBatchesExec: target_batch_size=3 +04)------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(binary_col@0, binary_col@0)] +05)--------MemoryExec: partitions=1, partition_sizes=[1] +06)--------MemoryExec: partitions=1, partition_sizes=[1] From 398d5f653be57963f1f5bd26231912fe3d967fc6 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 19 Nov 2024 09:58:44 -0500 Subject: [PATCH 30/45] Minor: Fix broken links for meetups in content library (#13445) * Minor: Fix broken links for meetups in content library * prettier * fix links, add Amsterdam meetup * Update docs/source/user-guide/concepts-readings-events.md Co-authored-by: Jonah Gao --------- Co-authored-by: Jonah Gao --- docs/source/user-guide/concepts-readings-events.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/user-guide/concepts-readings-events.md b/docs/source/user-guide/concepts-readings-events.md index 092f8433d47b5..135fbc47ad904 100644 --- a/docs/source/user-guide/concepts-readings-events.md +++ b/docs/source/user-guide/concepts-readings-events.md @@ -131,10 +131,11 @@ This is a list of DataFusion related blog posts, articles, and other resources. # 🌎 Community Events +- **2025-01-25** (Upcoming) [Amsterdam Apache DataFusion Meetup](https://github.com/apache/datafusion/discussions/12988) - **2025-01-15** (Upcoming) [Boston Apache DataFusion Meetup](https://github.com/apache/datafusion/discussions/13165) - **2024-12-18** (Upcoming) [Chicago Apache DataFusion Meetup](https://lu.ma/eq5myc5i) +- **2024-10-14** [Seattle Apache DataFusion Meetup](https://lu.ma/tnwl866b) - **2024-09-27** [Belgrade Apache DataFusion Meetup](https://lu.ma/tmwuz4lg), [recap](https://github.com/apache/datafusion/discussions/11431#discussioncomment-10832070), [slides](https://github.com/apache/datafusion/discussions/11431#discussioncomment-10826169), [recordings](https://www.youtube.com/watch?v=4huEsFFv6bQ&list=PLrhIfEjaw9ilQEczOQlHyMznabtVRptyX) - **2024-06-26** [New York City Apache DataFusion Meetup](https://lu.ma/2iwba0xm). [slides](https://docs.google.com/presentation/d/1dOLPAFPEMLhLv4NN6O9QSDIyyeiIySqAjky5cVgdWAE/edit#slide=id.g26bebde4fcc_3_7) - **2024-06-25** [San Francisco Bay Area Apache DataFusion Meetup](https://lu.ma/6bphole2). [slides](https://docs.google.com/presentation/d/1Oz2yGllrWBkNGyiRMLr8qXTt4vmvtJWuI_weGThaZak/edit#slide=id.g26bebde4fcc_3_7) - **2024-03-27** [Austin Apache DataFusion Meetup](https://github.com/apache/datafusion/discussions/8522). [slides](https://docs.google.com/presentation/d/1S51TK8waxHEJaxi_-uiSMrgQZ09m_hfaasPk5X5ExEY), [recording](https://www.youtube.com/watch?v=q1N3pH3tFw8) -- **2024-03-26** [Seattle Apache DataFusion Meetup]( From 96d76ab8e095c4c4d034e1b261761fadb5e3be99 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Tue, 19 Nov 2024 17:30:57 +0100 Subject: [PATCH 31/45] Remove redundant dead_code check suppressions (#13490) --- datafusion/core/src/datasource/physical_plan/arrow_file.rs | 1 - datafusion/core/src/datasource/physical_plan/avro.rs | 1 - 2 files changed, 2 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/arrow_file.rs b/datafusion/core/src/datasource/physical_plan/arrow_file.rs index df5ede5e8391e..8df5ef82cd0c8 100644 --- a/datafusion/core/src/datasource/physical_plan/arrow_file.rs +++ b/datafusion/core/src/datasource/physical_plan/arrow_file.rs @@ -46,7 +46,6 @@ use object_store::{GetOptions, GetRange, GetResultPayload, ObjectStore}; /// Execution plan for scanning Arrow data source #[derive(Debug, Clone)] -#[allow(dead_code)] pub struct ArrowExec { base_config: FileScanConfig, projected_statistics: Statistics, diff --git a/datafusion/core/src/datasource/physical_plan/avro.rs b/datafusion/core/src/datasource/physical_plan/avro.rs index 2e83be212f8b9..68d219ef0e5eb 100644 --- a/datafusion/core/src/datasource/physical_plan/avro.rs +++ b/datafusion/core/src/datasource/physical_plan/avro.rs @@ -34,7 +34,6 @@ use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; /// Execution plan for scanning Avro data source #[derive(Debug, Clone)] -#[allow(dead_code)] pub struct AvroExec { base_config: FileScanConfig, projected_statistics: Statistics, From c3681dc865f1fd3bde1e08d44c83bf7e1079464a Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Tue, 19 Nov 2024 23:46:14 +0100 Subject: [PATCH 32/45] chore: try make Setup Rust CI step immune to network hang (#13495) Try add a timeout to the apt-get invocation so that retry can kick in. --- .github/actions/setup-builder/action.yaml | 14 +++++++------- ci/scripts/retry | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/actions/setup-builder/action.yaml b/.github/actions/setup-builder/action.yaml index 0f45d51835f41..22d2f2187dd07 100644 --- a/.github/actions/setup-builder/action.yaml +++ b/.github/actions/setup-builder/action.yaml @@ -28,18 +28,18 @@ runs: - name: Install Build Dependencies shell: bash run: | - RETRY="ci/scripts/retry" - "${RETRY}" apt-get update - "${RETRY}" apt-get install -y protobuf-compiler + RETRY=("ci/scripts/retry" timeout 120) + "${RETRY[@]}" apt-get update + "${RETRY[@]}" apt-get install -y protobuf-compiler - name: Setup Rust toolchain shell: bash # rustfmt is needed for the substrait build script run: | - RETRY="ci/scripts/retry" + RETRY=("ci/scripts/retry" timeout 120) echo "Installing ${{ inputs.rust-version }}" - "${RETRY}" rustup toolchain install ${{ inputs.rust-version }} - "${RETRY}" rustup default ${{ inputs.rust-version }} - "${RETRY}" rustup component add rustfmt + "${RETRY[@]}" rustup toolchain install ${{ inputs.rust-version }} + "${RETRY[@]}" rustup default ${{ inputs.rust-version }} + "${RETRY[@]}" rustup component add rustfmt - name: Configure rust runtime env uses: ./.github/actions/setup-rust-runtime - name: Fixup git permissions diff --git a/ci/scripts/retry b/ci/scripts/retry index 0569dea58c94a..411dc532ca52f 100755 --- a/ci/scripts/retry +++ b/ci/scripts/retry @@ -7,7 +7,7 @@ x() { "$@" } -max_retry_time_seconds=$(( 3 * 60 )) +max_retry_time_seconds=$(( 5 * 60 )) retry_delay_seconds=10 END=$(( $(date +%s) + ${max_retry_time_seconds} )) From 30ff48e94c416387f52b852b15366b887b9c9fb2 Mon Sep 17 00:00:00 2001 From: irenjj Date: Wed, 20 Nov 2024 10:50:20 +0800 Subject: [PATCH 33/45] Move `Pruning` into `physical-optimizer` crate (#13485) * Move `Pruning` into `physical-optimizer` crate * fix check * fix issues * cargo update --- datafusion-cli/Cargo.lock | 64 ++++++++++--------- datafusion/core/src/physical_optimizer/mod.rs | 1 - datafusion/physical-optimizer/Cargo.toml | 4 ++ datafusion/physical-optimizer/src/lib.rs | 1 + .../src}/pruning.rs | 40 ++++++------ 5 files changed, 56 insertions(+), 54 deletions(-) rename datafusion/{core/src/physical_optimizer => physical-optimizer/src}/pruning.rs (99%) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index bfd0411798c96..c5576b7e7d444 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -567,9 +567,9 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.49.0" +version = "1.50.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53dcf5e7d9bd1517b8b998e170e650047cea8a2b85fe1835abe3210713e541b7" +checksum = "6ada54e5f26ac246dc79727def52f7f8ed38915cb47781e2a72213957dc3a7d5" dependencies = [ "aws-credential-types", "aws-runtime", @@ -857,9 +857,9 @@ dependencies = [ [[package]] name = "bstr" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40723b8fb387abc38f4f4a37c09073622e41dd12327033091ef8950659e6dc0c" +checksum = "1a68f1f47cdf0ec8ee4b941b2eee2a80cb796db73118c0dd09ac63fbe405be22" dependencies = [ "memchr", "regex-automata", @@ -917,9 +917,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1aeb932158bd710538c73702db6945cb68a8fb08c519e6e12706b94263b36db8" +checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" dependencies = [ "jobserver", "libc", @@ -980,9 +980,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.20" +version = "4.5.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8" +checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f" dependencies = [ "clap_builder", "clap_derive", @@ -990,9 +990,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.20" +version = "4.5.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54" +checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec" dependencies = [ "anstream", "anstyle", @@ -1014,9 +1014,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" +checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7" [[package]] name = "clipboard-win" @@ -1035,9 +1035,9 @@ checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" [[package]] name = "comfy-table" -version = "7.1.2" +version = "7.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0d05af1e006a2407bedef5af410552494ce5be9090444dbbcb57258c1af3d56" +checksum = "24f165e7b643266ea80cb858aed492ad9280e3e05ce24d4a99d7d7b889b6a4d9" dependencies = [ "strum 0.26.3", "strum_macros 0.26.4", @@ -1537,9 +1537,11 @@ dependencies = [ "datafusion-common", "datafusion-execution", "datafusion-expr-common", + "datafusion-optimizer", "datafusion-physical-expr", "datafusion-physical-plan", "itertools", + "log", "recursive", ] @@ -1749,9 +1751,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.34" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0" +checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c" dependencies = [ "crc32fast", "miniz_oxide", @@ -2162,7 +2164,7 @@ dependencies = [ "http 1.1.0", "hyper 1.5.0", "hyper-util", - "rustls 0.23.16", + "rustls 0.23.17", "rustls-native-certs 0.8.0", "rustls-pki-types", "tokio", @@ -2484,9 +2486,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.162" +version = "0.2.164" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18d287de67fe55fd7e1581fe933d965a5a9477b38e949cfa9f8574ef01506398" +checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" [[package]] name = "libflate" @@ -3073,7 +3075,7 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash", - "rustls 0.23.16", + "rustls 0.23.17", "socket2", "thiserror 2.0.3", "tokio", @@ -3091,7 +3093,7 @@ dependencies = [ "rand", "ring", "rustc-hash", - "rustls 0.23.16", + "rustls 0.23.17", "rustls-pki-types", "slab", "thiserror 2.0.3", @@ -3269,7 +3271,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.16", + "rustls 0.23.17", "rustls-native-certs 0.8.0", "rustls-pemfile 2.2.0", "rustls-pki-types", @@ -3363,9 +3365,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.40" +version = "0.38.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99e4ea3e1cdc4b559b8e5650f9c8e5998e3e5c1343b4eaf034565f32318d63c0" +checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6" dependencies = [ "bitflags 2.6.0", "errno", @@ -3388,9 +3390,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.16" +version = "0.23.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eee87ff5d9b36712a58574e12e9f0ea80f915a5b0ac518d322b24a465617925e" +checksum = "7f1a745511c54ba6d4465e8d5dfbd81b45791756de28d4981af70d6dca128f1e" dependencies = [ "once_cell", "ring", @@ -3518,9 +3520,9 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01227be5826fa0690321a2ba6c5cd57a19cf3f6a09e76973b58e61de6ab9d1c1" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" dependencies = [ "windows-sys 0.59.0", ] @@ -3598,9 +3600,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.132" +version = "1.0.133" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" +checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" dependencies = [ "itoa", "memchr", @@ -4019,7 +4021,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.16", + "rustls 0.23.17", "rustls-pki-types", "tokio", ] diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index a9f6f30dc1753..000c27effdb69 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -27,7 +27,6 @@ pub mod enforce_sorting; pub mod join_selection; pub mod optimizer; pub mod projection_pushdown; -pub mod pruning; pub mod replace_with_order_preserving_variants; pub mod sanity_checker; #[cfg(test)] diff --git a/datafusion/physical-optimizer/Cargo.toml b/datafusion/physical-optimizer/Cargo.toml index 04f01f8badb86..718567de8df4f 100644 --- a/datafusion/physical-optimizer/Cargo.toml +++ b/datafusion/physical-optimizer/Cargo.toml @@ -37,11 +37,15 @@ arrow-schema = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-execution = { workspace = true } datafusion-expr-common = { workspace = true, default-features = true } +datafusion-optimizer = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-plan = { workspace = true } itertools = { workspace = true } +log = { workspace = true } recursive = { workspace = true } [dev-dependencies] +datafusion-expr = { workspace = true } datafusion-functions-aggregate = { workspace = true } +datafusion-functions-nested = { workspace = true } tokio = { workspace = true } diff --git a/datafusion/physical-optimizer/src/lib.rs b/datafusion/physical-optimizer/src/lib.rs index 5d0ccde9f8cdc..c4f5fa74e1225 100644 --- a/datafusion/physical-optimizer/src/lib.rs +++ b/datafusion/physical-optimizer/src/lib.rs @@ -25,6 +25,7 @@ pub mod limit_pushdown; pub mod limited_distinct_aggregation; mod optimizer; pub mod output_requirements; +pub mod pruning; pub mod topk_aggregation; pub mod update_aggr_exprs; diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/physical-optimizer/src/pruning.rs similarity index 99% rename from datafusion/core/src/physical_optimizer/pruning.rs rename to datafusion/physical-optimizer/src/pruning.rs index 89b86471561ed..3cfb03b7205a5 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/physical-optimizer/src/pruning.rs @@ -18,33 +18,30 @@ //! [`PruningPredicate`] to apply filter [`Expr`] to prune "containers" //! based on statistics (e.g. Parquet Row Groups) //! -//! [`Expr`]: crate::prelude::Expr +//! [`Expr`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html use std::collections::HashSet; use std::sync::Arc; -use crate::{ - common::{Column, DFSchema}, - error::{DataFusionError, Result}, - logical_expr::Operator, - physical_plan::{ColumnarValue, PhysicalExpr}, -}; - +use arrow::array::AsArray; use arrow::{ array::{new_null_array, ArrayRef, BooleanArray}, datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::{RecordBatch, RecordBatchOptions}, }; -use arrow_array::cast::AsArray; +use log::trace; + +use datafusion_common::error::{DataFusionError, Result}; use datafusion_common::tree_node::TransformedResult; use datafusion_common::{ internal_err, plan_datafusion_err, plan_err, tree_node::{Transformed, TreeNode}, ScalarValue, }; +use datafusion_common::{Column, DFSchema}; +use datafusion_expr_common::operator::Operator; use datafusion_physical_expr::utils::{collect_columns, Guarantee, LiteralGuarantee}; use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; - -use log::trace; +use datafusion_physical_plan::{ColumnarValue, PhysicalExpr}; /// A source of runtime statistical information to [`PruningPredicate`]s. /// @@ -567,7 +564,7 @@ impl PruningPredicate { /// expressions like `b = false`, but it does handle the /// simplified version `b`. See [`ExprSimplifier`] to simplify expressions. /// - /// [`ExprSimplifier`]: crate::optimizer::simplify_expressions::ExprSimplifier + /// [`ExprSimplifier`]: https://docs.rs/datafusion/latest/datafusion/optimizer/simplify_expressions/struct.ExprSimplifier.html pub fn prune(&self, statistics: &S) -> Result> { let mut builder = BoolVecBuilder::new(statistics.num_containers()); @@ -653,7 +650,7 @@ impl PruningPredicate { // this is only used by `parquet` feature right now #[allow(dead_code)] - pub(crate) fn required_columns(&self) -> &RequiredColumns { + pub fn required_columns(&self) -> &RequiredColumns { &self.required_columns } @@ -762,7 +759,7 @@ fn is_always_true(expr: &Arc) -> bool { /// Handles creating references to the min/max statistics /// for columns as well as recording which statistics are needed #[derive(Debug, Default, Clone)] -pub(crate) struct RequiredColumns { +pub struct RequiredColumns { /// The statistics required to evaluate this predicate: /// * The unqualified column in the input schema /// * Statistics type (e.g. Min or Max or Null_Count) @@ -786,7 +783,7 @@ impl RequiredColumns { /// * `true` returns None #[allow(dead_code)] // this fn is only used by `parquet` feature right now, thus the `allow(dead_code)` - pub(crate) fn single_column(&self) -> Option<&phys_expr::Column> { + pub fn single_column(&self) -> Option<&phys_expr::Column> { if self.columns.windows(2).all(|w| { // check if all columns are the same (ignoring statistics and field) let c1 = &w[0].0; @@ -1664,15 +1661,14 @@ mod tests { use std::ops::{Not, Rem}; use super::*; - use crate::assert_batches_eq; - use crate::logical_expr::{col, lit}; + use datafusion_common::assert_batches_eq; + use datafusion_expr::{col, lit}; use arrow::array::Decimal128Array; use arrow::{ - array::{BinaryArray, Int32Array, Int64Array, StringArray}, + array::{BinaryArray, Int32Array, Int64Array, StringArray, UInt64Array}, datatypes::TimeUnit, }; - use arrow_array::UInt64Array; use datafusion_expr::expr::InList; use datafusion_expr::{cast, is_null, try_cast, Expr}; use datafusion_functions_nested::expr_fn::{array_has, make_array}; @@ -3536,7 +3532,7 @@ mod tests { // more complex case with unknown column let input = known_expression.clone().and(input.clone()); let expected = phys_expr::BinaryExpr::new( - known_expression_transformed.clone(), + Arc::::clone(&known_expression_transformed), Operator::And, logical2physical(&lit(42), &schema), ); @@ -3552,7 +3548,7 @@ mod tests { // more complex case with unknown expression let input = known_expression.and(input); let expected = phys_expr::BinaryExpr::new( - known_expression_transformed.clone(), + Arc::::clone(&known_expression_transformed), Operator::And, logical2physical(&lit(42), &schema), ); @@ -4038,7 +4034,7 @@ mod tests { ) { println!("Pruning with expr: {}", expr); let expr = logical2physical(&expr, schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); + let p = PruningPredicate::try_new(expr, Arc::::clone(schema)).unwrap(); let result = p.prune(statistics).unwrap(); assert_eq!(result, expected); } From aef232b1ac559fc1597a64d9f5f75c2f29f4c286 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 20 Nov 2024 08:29:54 +0100 Subject: [PATCH 34/45] Add `Container` trait and to simplify `Expr` and `LogicalPlan` apply and map methods (#13467) * Add `Container` trait and its blanket implementations, remove `map_until_stop_and_collect` macro, simplify apply and map logic with `Container`s where possible * fix clippy * rename `Container` to `TreeNodeContainer` * add docs to containers * clarify when we need a temporary `TreeNodeRefContainer` * code and docs cleanup --- datafusion/common/src/tree_node.rs | 363 ++++++++++++++--- datafusion/expr/src/expr.rs | 36 +- datafusion/expr/src/logical_plan/ddl.rs | 50 ++- datafusion/expr/src/logical_plan/plan.rs | 20 +- datafusion/expr/src/logical_plan/statement.rs | 51 +-- datafusion/expr/src/logical_plan/tree_node.rs | 347 +++++++--------- datafusion/expr/src/tree_node.rs | 372 ++++++------------ .../optimizer/src/optimize_projections/mod.rs | 4 +- datafusion/sql/src/unparser/rewrite.rs | 24 +- 9 files changed, 687 insertions(+), 580 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index c8ec7f18339a8..0c153583e34b1 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -17,11 +17,12 @@ //! [`TreeNode`] for visiting and rewriting expression and plan trees +use crate::Result; use recursive::recursive; +use std::collections::HashMap; +use std::hash::Hash; use std::sync::Arc; -use crate::Result; - /// These macros are used to determine continuation during transforming traversals. macro_rules! handle_transform_recursion { ($F_DOWN:expr, $F_CHILD:expr, $F_UP:expr) => {{ @@ -769,6 +770,297 @@ impl Transformed { } } +/// [`TreeNodeContainer`] contains elements that a function can be applied on or mapped. +/// The elements of the container are siblings so the continuation rules are similar to +/// [`TreeNodeRecursion::visit_sibling`] / [`Transformed::transform_sibling`]. +pub trait TreeNodeContainer<'a, T: 'a>: Sized { + /// Applies `f` to all elements of the container. + /// This method is usually called from [`TreeNode::apply_children`] implementations as + /// a node is actually a container of the node's children. + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result; + + /// Maps all elements of the container with `f`. + /// This method is usually called from [`TreeNode::map_children`] implementations as + /// a node is actually a container of the node's children. + fn map_elements Result>>( + self, + f: F, + ) -> Result>; +} + +impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Box { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + self.as_ref().apply_elements(f) + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + (*self).map_elements(f)?.map_data(|c| Ok(Self::new(c))) + } +} + +impl<'a, T: 'a, C: TreeNodeContainer<'a, T> + Clone> TreeNodeContainer<'a, T> for Arc { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + self.as_ref().apply_elements(f) + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + Arc::unwrap_or_clone(self) + .map_elements(f)? + .map_data(|c| Ok(Arc::new(c))) + } +} + +impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Option { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + match self { + Some(t) => t.apply_elements(f), + None => Ok(TreeNodeRecursion::Continue), + } + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + self.map_or(Ok(Transformed::no(None)), |c| { + c.map_elements(f)?.map_data(|c| Ok(Some(c))) + }) + } +} + +impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Vec { + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + let mut tnr = TreeNodeRecursion::Continue; + for c in self { + tnr = c.apply_elements(&mut f)?; + match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {} + TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), + } + } + Ok(tnr) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + let mut tnr = TreeNodeRecursion::Continue; + let mut transformed = false; + self.into_iter() + .map(|c| match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { + c.map_elements(&mut f).map(|result| { + tnr = result.tnr; + transformed |= result.transformed; + result.data + }) + } + TreeNodeRecursion::Stop => Ok(c), + }) + .collect::>>() + .map(|data| Transformed::new(data, transformed, tnr)) + } +} + +impl<'a, T: 'a, K: Eq + Hash, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> + for HashMap +{ + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + let mut tnr = TreeNodeRecursion::Continue; + for c in self.values() { + tnr = c.apply_elements(&mut f)?; + match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {} + TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), + } + } + Ok(tnr) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + let mut tnr = TreeNodeRecursion::Continue; + let mut transformed = false; + self.into_iter() + .map(|(k, c)| match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { + c.map_elements(&mut f).map(|result| { + tnr = result.tnr; + transformed |= result.transformed; + (k, result.data) + }) + } + TreeNodeRecursion::Stop => Ok((k, c)), + }) + .collect::>>() + .map(|data| Transformed::new(data, transformed, tnr)) + } +} + +impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>> + TreeNodeContainer<'a, T> for (C0, C1) +{ + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + self.0 + .apply_elements(&mut f)? + .visit_sibling(|| self.1.apply_elements(&mut f)) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + self.0 + .map_elements(&mut f)? + .map_data(|new_c0| Ok((new_c0, self.1)))? + .transform_sibling(|(new_c0, c1)| { + c1.map_elements(&mut f)? + .map_data(|new_c1| Ok((new_c0, new_c1))) + }) + } +} + +impl< + 'a, + T: 'a, + C0: TreeNodeContainer<'a, T>, + C1: TreeNodeContainer<'a, T>, + C2: TreeNodeContainer<'a, T>, + > TreeNodeContainer<'a, T> for (C0, C1, C2) +{ + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + self.0 + .apply_elements(&mut f)? + .visit_sibling(|| self.1.apply_elements(&mut f))? + .visit_sibling(|| self.2.apply_elements(&mut f)) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + self.0 + .map_elements(&mut f)? + .map_data(|new_c0| Ok((new_c0, self.1, self.2)))? + .transform_sibling(|(new_c0, c1, c2)| { + c1.map_elements(&mut f)? + .map_data(|new_c1| Ok((new_c0, new_c1, c2))) + })? + .transform_sibling(|(new_c0, new_c1, c2)| { + c2.map_elements(&mut f)? + .map_data(|new_c2| Ok((new_c0, new_c1, new_c2))) + }) + } +} + +/// [`TreeNodeRefContainer`] contains references to elements that a function can be +/// applied on. The elements of the container are siblings so the continuation rules are +/// similar to [`TreeNodeRecursion::visit_sibling`]. +/// +/// This container is similar to [`TreeNodeContainer`], but the lifetime of the reference +/// elements (`T`) are not derived from the container's lifetime. +/// A typical usage of this container is in `Expr::apply_children` when we need to +/// construct a temporary container to be able to call `apply_ref_elements` on a +/// collection of tree node references. But in that case the container's temporary +/// lifetime is different to the lifetime of tree nodes that we put into it. +/// Please find an example usecase in `Expr::apply_children` with the `Expr::Case` case. +/// +/// Most of the cases we don't need to create a temporary container with +/// `TreeNodeRefContainer`, but we can just call `TreeNodeContainer::apply_elements`. +/// Please find an example usecase in `Expr::apply_children` with the `Expr::GroupingSet` +/// case. +pub trait TreeNodeRefContainer<'a, T: 'a>: Sized { + /// Applies `f` to all elements of the container. + /// This method is usually called from [`TreeNode::apply_children`] implementations as + /// a node is actually a container of the node's children. + fn apply_ref_elements Result>( + &self, + f: F, + ) -> Result; +} + +impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeRefContainer<'a, T> for Vec<&'a C> { + fn apply_ref_elements Result>( + &self, + mut f: F, + ) -> Result { + let mut tnr = TreeNodeRecursion::Continue; + for c in self { + tnr = c.apply_elements(&mut f)?; + match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {} + TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), + } + } + Ok(tnr) + } +} + +impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>> + TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1) +{ + fn apply_ref_elements Result>( + &self, + mut f: F, + ) -> Result { + self.0 + .apply_elements(&mut f)? + .visit_sibling(|| self.1.apply_elements(&mut f)) + } +} + +impl< + 'a, + T: 'a, + C0: TreeNodeContainer<'a, T>, + C1: TreeNodeContainer<'a, T>, + C2: TreeNodeContainer<'a, T>, + > TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2) +{ + fn apply_ref_elements Result>( + &self, + mut f: F, + ) -> Result { + self.0 + .apply_elements(&mut f)? + .visit_sibling(|| self.1.apply_elements(&mut f))? + .visit_sibling(|| self.2.apply_elements(&mut f)) + } +} + /// Transformation helper to process a sequence of iterable tree nodes that are siblings. pub trait TreeNodeIterator: Iterator { /// Apples `f` to each item in this iterator @@ -843,50 +1135,6 @@ impl TreeNodeIterator for I { } } -/// Transformation helper to process a heterogeneous sequence of tree node containing -/// expressions. -/// -/// This macro is very similar to [TreeNodeIterator::map_until_stop_and_collect] to -/// process nodes that are siblings, but it accepts an initial transformation (`F0`) and -/// a sequence of pairs. Each pair is made of an expression (`EXPR`) and its -/// transformation (`F`). -/// -/// The macro builds up a tuple that contains `Transformed.data` result of `F0` as the -/// first element and further elements from the sequence of pairs. An element from a pair -/// is either the value of `EXPR` or the `Transformed.data` result of `F`, depending on -/// the `Transformed.tnr` result of previous `F`s (`F0` initially). -/// -/// # Returns -/// Error if any of the transformations returns an error -/// -/// Ok(Transformed<(data0, ..., dataN)>) such that: -/// 1. `transformed` is true if any of the transformations had transformed true -/// 2. `(data0, ..., dataN)`, where `data0` is the `Transformed.data` from `F0` and -/// `data1` ... `dataN` are from either `EXPR` or the `Transformed.data` of `F` -/// 3. `tnr` from `F0` or the last invocation of `F` -#[macro_export] -macro_rules! map_until_stop_and_collect { - ($F0:expr, $($EXPR:expr, $F:expr),*) => {{ - $F0.and_then(|Transformed { data: data0, mut transformed, mut tnr }| { - let all_datas = ( - data0, - $( - if tnr == TreeNodeRecursion::Continue || tnr == TreeNodeRecursion::Jump { - $F.map(|result| { - tnr = result.tnr; - transformed |= result.transformed; - result.data - })? - } else { - $EXPR - }, - )* - ); - Ok(Transformed::new(all_datas, transformed, tnr)) - }) - }} -} - /// Transformation helper to access [`Transformed`] fields in a [`Result`] easily. /// /// # Example @@ -1021,7 +1269,7 @@ pub(crate) mod tests { use std::fmt::Display; use crate::tree_node::{ - Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter, + Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; use crate::Result; @@ -1054,7 +1302,7 @@ pub(crate) mod tests { &'n self, f: F, ) -> Result { - self.children.iter().apply_until_stop(f) + self.children.apply_elements(f) } fn map_children Result>>( @@ -1063,8 +1311,7 @@ pub(crate) mod tests { ) -> Result> { Ok(self .children - .into_iter() - .map_until_stop_and_collect(f)? + .map_elements(f)? .update_data(|new_children| Self { children: new_children, ..self @@ -1072,6 +1319,22 @@ pub(crate) mod tests { } } + impl<'a, T: 'a> TreeNodeContainer<'a, Self> for TestTreeNode { + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + f(self) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + f(self) + } + } + // J // | // I diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 83d35c3d25b16..8490c08a70bbb 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -32,7 +32,7 @@ use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::cse::HashNode; use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, + Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, }; use datafusion_common::{ plan_err, Column, DFSchema, HashMap, Result, ScalarValue, TableReference, @@ -351,6 +351,22 @@ impl<'a> From<(Option<&'a TableReference>, &'a FieldRef)> for Expr { } } +impl<'a> TreeNodeContainer<'a, Self> for Expr { + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + f(self) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + f(self) + } +} + /// UNNEST expression. #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct Unnest { @@ -653,6 +669,24 @@ impl Display for Sort { } } +impl<'a> TreeNodeContainer<'a, Expr> for Sort { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + self.expr.apply_elements(f) + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + self.expr + .map_elements(f)? + .map_data(|expr| Ok(Self { expr, ..self })) + } +} + /// Aggregate function /// /// See also [`ExprFunctionExt`] to set these fields on `Expr` diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index 93e8b5fd045e7..8c64a017988e9 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -26,7 +26,10 @@ use std::{ use crate::expr::Sort; use arrow::datatypes::DataType; -use datafusion_common::{Constraints, DFSchemaRef, SchemaReference, TableReference}; +use datafusion_common::tree_node::{Transformed, TreeNodeContainer, TreeNodeRecursion}; +use datafusion_common::{ + Constraints, DFSchemaRef, Result, SchemaReference, TableReference, +}; use sqlparser::ast::Ident; /// Various types of DDL (CREATE / DROP) catalog manipulation @@ -487,6 +490,28 @@ pub struct OperateFunctionArg { pub data_type: DataType, pub default_expr: Option, } + +impl<'a> TreeNodeContainer<'a, Expr> for OperateFunctionArg { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + self.default_expr.apply_elements(f) + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + self.default_expr.map_elements(f)?.map_data(|default_expr| { + Ok(Self { + default_expr, + ..self + }) + }) + } +} + #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct CreateFunctionBody { /// LANGUAGE lang_name @@ -497,6 +522,29 @@ pub struct CreateFunctionBody { pub function_body: Option, } +impl<'a> TreeNodeContainer<'a, Expr> for CreateFunctionBody { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + self.function_body.apply_elements(f) + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + self.function_body + .map_elements(f)? + .map_data(|function_body| { + Ok(Self { + function_body, + ..self + }) + }) + } +} + #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct DropFunction { pub name: String, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 6ee99b22c7f3c..e9f4f1f80972d 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -45,7 +45,9 @@ use crate::{ }; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, +}; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence, @@ -287,6 +289,22 @@ impl Default for LogicalPlan { } } +impl<'a> TreeNodeContainer<'a, Self> for LogicalPlan { + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + f(self) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + f(self) + } +} + impl LogicalPlan { /// Get a reference to the logical plan's schema pub fn schema(&self) -> &DFSchemaRef { diff --git a/datafusion/expr/src/logical_plan/statement.rs b/datafusion/expr/src/logical_plan/statement.rs index 05e2b1af14d3b..26df379f5e4ad 100644 --- a/datafusion/expr/src/logical_plan/statement.rs +++ b/datafusion/expr/src/logical_plan/statement.rs @@ -16,12 +16,10 @@ // under the License. use arrow::datatypes::DataType; -use datafusion_common::tree_node::{Transformed, TreeNodeIterator}; -use datafusion_common::{DFSchema, DFSchemaRef, Result}; +use datafusion_common::{DFSchema, DFSchemaRef}; use std::fmt::{self, Display}; use std::sync::{Arc, OnceLock}; -use super::tree_node::rewrite_arc; use crate::{expr_vec_fmt, Expr, LogicalPlan}; /// Statements have a unchanging empty schema. @@ -80,53 +78,6 @@ impl Statement { } } - /// Rewrites input LogicalPlans in the current `Statement` using `f`. - pub(super) fn map_inputs< - F: FnMut(LogicalPlan) -> Result>, - >( - self, - f: F, - ) -> Result> { - match self { - Statement::Prepare(Prepare { - input, - name, - data_types, - }) => Ok(rewrite_arc(input, f)?.update_data(|input| { - Statement::Prepare(Prepare { - input, - name, - data_types, - }) - })), - _ => Ok(Transformed::no(self)), - } - } - - /// Returns a iterator over all expressions in the current `Statement`. - pub(super) fn expression_iter(&self) -> impl Iterator { - match self { - Statement::Execute(Execute { parameters, .. }) => parameters.iter(), - _ => [].iter(), - } - } - - /// Rewrites all expressions in the current `Statement` using `f`. - pub(super) fn map_expressions Result>>( - self, - f: F, - ) -> Result> { - match self { - Statement::Execute(Execute { name, parameters }) => Ok(parameters - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|parameters| { - Statement::Execute(Execute { parameters, name }) - })), - _ => Ok(Transformed::no(self)), - } - } - /// Return a `format`able structure with the a human readable /// description of this LogicalPlan node per node, not including /// children. diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index e7dfe87919241..6850c30f4f81b 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -36,32 +36,30 @@ //! (Re)creation APIs (these require substantial cloning and thus are slow): //! * [`LogicalPlan::with_new_exprs`]: Create a new plan with different expressions //! * [`LogicalPlan::expressions`]: Return a copy of the plan's expressions + use crate::{ dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView, DdlStatement, - Distinct, DistinctOn, DmlStatement, Explain, Expr, Extension, Filter, Join, Limit, - LogicalPlan, Partitioning, Projection, RecursiveQuery, Repartition, Sort, Subquery, - SubqueryAlias, TableScan, Union, Unnest, UserDefinedLogicalNode, Values, Window, + Distinct, DistinctOn, DmlStatement, Execute, Explain, Expr, Extension, Filter, Join, + Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, Repartition, + Sort, Statement, Subquery, SubqueryAlias, TableScan, Union, Unnest, + UserDefinedLogicalNode, Values, Window, }; +use datafusion_common::tree_node::TreeNodeRefContainer; use recursive::recursive; -use std::ops::Deref; -use std::sync::Arc; use crate::expr::{Exists, InSubquery}; -use crate::tree_node::{transform_sort_option_vec, transform_sort_vec}; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter, - TreeNodeVisitor, -}; -use datafusion_common::{ - internal_err, map_until_stop_and_collect, DataFusionError, Result, + Transformed, TreeNode, TreeNodeContainer, TreeNodeIterator, TreeNodeRecursion, + TreeNodeRewriter, TreeNodeVisitor, }; +use datafusion_common::{internal_err, Result}; impl TreeNode for LogicalPlan { fn apply_children<'n, F: FnMut(&'n Self) -> Result>( &'n self, f: F, ) -> Result { - self.inputs().into_iter().apply_until_stop(f) + self.inputs().apply_ref_elements(f) } /// Applies `f` to each child (input) of this plan node, rewriting them *in place.* @@ -74,14 +72,14 @@ impl TreeNode for LogicalPlan { /// [`Expr::Exists`]: crate::Expr::Exists fn map_children Result>>( self, - mut f: F, + f: F, ) -> Result> { Ok(match self { LogicalPlan::Projection(Projection { expr, input, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Projection(Projection { expr, input, @@ -92,7 +90,7 @@ impl TreeNode for LogicalPlan { predicate, input, having, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Filter(Filter { predicate, input, @@ -102,7 +100,7 @@ impl TreeNode for LogicalPlan { LogicalPlan::Repartition(Repartition { input, partitioning_scheme, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Repartition(Repartition { input, partitioning_scheme, @@ -112,7 +110,7 @@ impl TreeNode for LogicalPlan { input, window_expr, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Window(Window { input, window_expr, @@ -124,7 +122,7 @@ impl TreeNode for LogicalPlan { group_expr, aggr_expr, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Aggregate(Aggregate { input, group_expr, @@ -132,7 +130,8 @@ impl TreeNode for LogicalPlan { schema, }) }), - LogicalPlan::Sort(Sort { expr, input, fetch }) => rewrite_arc(input, f)? + LogicalPlan::Sort(Sort { expr, input, fetch }) => input + .map_elements(f)? .update_data(|input| LogicalPlan::Sort(Sort { expr, input, fetch })), LogicalPlan::Join(Join { left, @@ -143,12 +142,7 @@ impl TreeNode for LogicalPlan { join_constraint, schema, null_equals_null, - }) => map_until_stop_and_collect!( - rewrite_arc(left, &mut f), - right, - rewrite_arc(right, &mut f) - )? - .update_data(|(left, right)| { + }) => (left, right).map_elements(f)?.update_data(|(left, right)| { LogicalPlan::Join(Join { left, right, @@ -160,12 +154,13 @@ impl TreeNode for LogicalPlan { null_equals_null, }) }), - LogicalPlan::Limit(Limit { skip, fetch, input }) => rewrite_arc(input, f)? + LogicalPlan::Limit(Limit { skip, fetch, input }) => input + .map_elements(f)? .update_data(|input| LogicalPlan::Limit(Limit { skip, fetch, input })), LogicalPlan::Subquery(Subquery { subquery, outer_ref_columns, - }) => rewrite_arc(subquery, f)?.update_data(|subquery| { + }) => subquery.map_elements(f)?.update_data(|subquery| { LogicalPlan::Subquery(Subquery { subquery, outer_ref_columns, @@ -175,7 +170,7 @@ impl TreeNode for LogicalPlan { input, alias, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, @@ -184,17 +179,18 @@ impl TreeNode for LogicalPlan { }), LogicalPlan::Extension(extension) => rewrite_extension_inputs(extension, f)? .update_data(LogicalPlan::Extension), - LogicalPlan::Union(Union { inputs, schema }) => rewrite_arcs(inputs, f)? + LogicalPlan::Union(Union { inputs, schema }) => inputs + .map_elements(f)? .update_data(|inputs| LogicalPlan::Union(Union { inputs, schema })), LogicalPlan::Distinct(distinct) => match distinct { - Distinct::All(input) => rewrite_arc(input, f)?.update_data(Distinct::All), + Distinct::All(input) => input.map_elements(f)?.update_data(Distinct::All), Distinct::On(DistinctOn { on_expr, select_expr, sort_expr, input, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { Distinct::On(DistinctOn { on_expr, select_expr, @@ -211,7 +207,7 @@ impl TreeNode for LogicalPlan { stringified_plans, schema, logical_optimization_succeeded, - }) => rewrite_arc(plan, f)?.update_data(|plan| { + }) => plan.map_elements(f)?.update_data(|plan| { LogicalPlan::Explain(Explain { verbose, plan, @@ -224,7 +220,7 @@ impl TreeNode for LogicalPlan { verbose, input, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Analyze(Analyze { verbose, input, @@ -237,7 +233,7 @@ impl TreeNode for LogicalPlan { op, input, output_schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Dml(DmlStatement { table_name, table_schema, @@ -252,7 +248,7 @@ impl TreeNode for LogicalPlan { partition_by, file_type, options, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Copy(CopyTo { input, output_url, @@ -271,7 +267,7 @@ impl TreeNode for LogicalPlan { or_replace, column_defaults, temporary, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { DdlStatement::CreateMemoryTable(CreateMemoryTable { name, constraints, @@ -288,7 +284,7 @@ impl TreeNode for LogicalPlan { or_replace, definition, temporary, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { DdlStatement::CreateView(CreateView { name, input, @@ -318,7 +314,7 @@ impl TreeNode for LogicalPlan { dependency_indices, schema, options, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Unnest(Unnest { input, exec_columns: input_columns, @@ -334,22 +330,24 @@ impl TreeNode for LogicalPlan { static_term, recursive_term, is_distinct, - }) => map_until_stop_and_collect!( - rewrite_arc(static_term, &mut f), - recursive_term, - rewrite_arc(recursive_term, &mut f) - )? - .update_data(|(static_term, recursive_term)| { - LogicalPlan::RecursiveQuery(RecursiveQuery { - name, - static_term, - recursive_term, - is_distinct, - }) - }), - LogicalPlan::Statement(stmt) => { - stmt.map_inputs(f)?.update_data(LogicalPlan::Statement) + }) => (static_term, recursive_term).map_elements(f)?.update_data( + |(static_term, recursive_term)| { + LogicalPlan::RecursiveQuery(RecursiveQuery { + name, + static_term, + recursive_term, + is_distinct, + }) + }, + ), + LogicalPlan::Statement(stmt) => match stmt { + Statement::Prepare(p) => p + .input + .map_elements(f)? + .update_data(|input| Statement::Prepare(Prepare { input, ..p })), + _ => Transformed::no(stmt), } + .update_data(LogicalPlan::Statement), // plans without inputs LogicalPlan::TableScan { .. } | LogicalPlan::EmptyRelation { .. } @@ -359,24 +357,6 @@ impl TreeNode for LogicalPlan { } } -/// Applies `f` to rewrite a `Arc` without copying, if possible -pub(super) fn rewrite_arc Result>>( - plan: Arc, - mut f: F, -) -> Result>> { - f(Arc::unwrap_or_clone(plan))?.map_data(|new_plan| Ok(Arc::new(new_plan))) -} - -/// rewrite a `Vec` of `Arc` without copying, if possible -fn rewrite_arcs Result>>( - input_plans: Vec>, - mut f: F, -) -> Result>>> { - input_plans - .into_iter() - .map_until_stop_and_collect(|plan| rewrite_arc(plan, &mut f)) -} - /// Rewrites all inputs for an Extension node "in place" /// (it currently has to copy values because there are no APIs for in place modification) /// @@ -423,54 +403,40 @@ impl LogicalPlan { mut f: F, ) -> Result { match self { - LogicalPlan::Projection(Projection { expr, .. }) => { - expr.iter().apply_until_stop(f) - } - LogicalPlan::Values(Values { values, .. }) => values - .iter() - .apply_until_stop(|value| value.iter().apply_until_stop(&mut f)), + LogicalPlan::Projection(Projection { expr, .. }) => expr.apply_elements(f), + LogicalPlan::Values(Values { values, .. }) => values.apply_elements(f), LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate), LogicalPlan::Repartition(Repartition { partitioning_scheme, .. }) => match partitioning_scheme { Partitioning::Hash(expr, _) | Partitioning::DistributeBy(expr) => { - expr.iter().apply_until_stop(f) + expr.apply_elements(f) } Partitioning::RoundRobinBatch(_) => Ok(TreeNodeRecursion::Continue), }, LogicalPlan::Window(Window { window_expr, .. }) => { - window_expr.iter().apply_until_stop(f) + window_expr.apply_elements(f) } LogicalPlan::Aggregate(Aggregate { group_expr, aggr_expr, .. - }) => group_expr - .iter() - .chain(aggr_expr.iter()) - .apply_until_stop(f), + }) => (group_expr, aggr_expr).apply_ref_elements(f), // There are two part of expression for join, equijoin(on) and non-equijoin(filter). // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. // 2. the second part is non-equijoin(filter). LogicalPlan::Join(Join { on, filter, .. }) => { - on.iter() - // TODO: why we need to create an `Expr::eq`? Cloning `Expr` is costly... - // it not ideal to create an expr here to analyze them, but could cache it on the Join itself - .map(|(l, r)| Expr::eq(l.clone(), r.clone())) - .apply_until_stop(|e| f(&e))? - .visit_sibling(|| filter.iter().apply_until_stop(f)) - } - LogicalPlan::Sort(Sort { expr, .. }) => { - expr.iter().apply_until_stop(|sort| f(&sort.expr)) + (on, filter).apply_ref_elements(f) } + LogicalPlan::Sort(Sort { expr, .. }) => expr.apply_elements(f), LogicalPlan::Extension(extension) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs - extension.node.expressions().iter().apply_until_stop(f) + extension.node.expressions().apply_elements(f) } LogicalPlan::TableScan(TableScan { filters, .. }) => { - filters.iter().apply_until_stop(f) + filters.apply_elements(f) } LogicalPlan::Unnest(unnest) => { let columns = unnest.exec_columns.clone(); @@ -479,24 +445,23 @@ impl LogicalPlan { .iter() .map(|c| Expr::Column(c.clone())) .collect::>(); - exprs.iter().apply_until_stop(f) + exprs.apply_elements(f) } LogicalPlan::Distinct(Distinct::On(DistinctOn { on_expr, select_expr, sort_expr, .. - })) => on_expr - .iter() - .chain(select_expr.iter()) - .chain(sort_expr.iter().flatten().map(|sort| &sort.expr)) - .apply_until_stop(f), - LogicalPlan::Limit(Limit { skip, fetch, .. }) => skip - .iter() - .chain(fetch.iter()) - .map(|e| e.deref()) - .apply_until_stop(f), - LogicalPlan::Statement(stmt) => stmt.expression_iter().apply_until_stop(f), + })) => (on_expr, select_expr, sort_expr).apply_ref_elements(f), + LogicalPlan::Limit(Limit { skip, fetch, .. }) => { + (skip, fetch).apply_ref_elements(f) + } + LogicalPlan::Statement(stmt) => match stmt { + Statement::Execute(Execute { parameters, .. }) => { + parameters.apply_elements(f) + } + _ => Ok(TreeNodeRecursion::Continue), + }, // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::RecursiveQuery(_) @@ -529,21 +494,15 @@ impl LogicalPlan { expr, input, schema, - }) => expr - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|expr| { - LogicalPlan::Projection(Projection { - expr, - input, - schema, - }) - }), + }) => expr.map_elements(f)?.update_data(|expr| { + LogicalPlan::Projection(Projection { + expr, + input, + schema, + }) + }), LogicalPlan::Values(Values { schema, values }) => values - .into_iter() - .map_until_stop_and_collect(|value| { - value.into_iter().map_until_stop_and_collect(&mut f) - })? + .map_elements(f)? .update_data(|values| LogicalPlan::Values(Values { schema, values })), LogicalPlan::Filter(Filter { predicate, @@ -561,12 +520,10 @@ impl LogicalPlan { partitioning_scheme, }) => match partitioning_scheme { Partitioning::Hash(expr, usize) => expr - .into_iter() - .map_until_stop_and_collect(f)? + .map_elements(f)? .update_data(|expr| Partitioning::Hash(expr, usize)), Partitioning::DistributeBy(expr) => expr - .into_iter() - .map_until_stop_and_collect(f)? + .map_elements(f)? .update_data(Partitioning::DistributeBy), Partitioning::RoundRobinBatch(_) => Transformed::no(partitioning_scheme), } @@ -580,34 +537,28 @@ impl LogicalPlan { input, window_expr, schema, - }) => window_expr - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|window_expr| { - LogicalPlan::Window(Window { - input, - window_expr, - schema, - }) - }), + }) => window_expr.map_elements(f)?.update_data(|window_expr| { + LogicalPlan::Window(Window { + input, + window_expr, + schema, + }) + }), LogicalPlan::Aggregate(Aggregate { input, group_expr, aggr_expr, schema, - }) => map_until_stop_and_collect!( - group_expr.into_iter().map_until_stop_and_collect(&mut f), - aggr_expr, - aggr_expr.into_iter().map_until_stop_and_collect(&mut f) - )? - .update_data(|(group_expr, aggr_expr)| { - LogicalPlan::Aggregate(Aggregate { - input, - group_expr, - aggr_expr, - schema, - }) - }), + }) => (group_expr, aggr_expr).map_elements(f)?.update_data( + |(group_expr, aggr_expr)| { + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + aggr_expr, + schema, + }) + }, + ), // There are two part of expression for join, equijoin(on) and non-equijoin(filter). // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. @@ -621,16 +572,7 @@ impl LogicalPlan { join_constraint, schema, null_equals_null, - }) => map_until_stop_and_collect!( - on.into_iter().map_until_stop_and_collect( - |on| map_until_stop_and_collect!(f(on.0), on.1, f(on.1)) - ), - filter, - filter.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { - Ok(f(e)?.update_data(Some)) - }) - )? - .update_data(|(on, filter)| { + }) => (on, filter).map_elements(f)?.update_data(|(on, filter)| { LogicalPlan::Join(Join { left, right, @@ -642,17 +584,13 @@ impl LogicalPlan { null_equals_null, }) }), - LogicalPlan::Sort(Sort { expr, input, fetch }) => { - transform_sort_vec(expr, &mut f)? - .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })) - } + LogicalPlan::Sort(Sort { expr, input, fetch }) => expr + .map_elements(f)? + .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })), LogicalPlan::Extension(Extension { node }) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs - let exprs = node - .expressions() - .into_iter() - .map_until_stop_and_collect(f)?; + let exprs = node.expressions().map_elements(f)?; let plan = LogicalPlan::Extension(Extension { node: UserDefinedLogicalNode::with_exprs_and_inputs( node.as_ref(), @@ -669,64 +607,47 @@ impl LogicalPlan { projected_schema, filters, fetch, - }) => filters - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|filters| { - LogicalPlan::TableScan(TableScan { - table_name, - source, - projection, - projected_schema, - filters, - fetch, - }) - }), + }) => filters.map_elements(f)?.update_data(|filters| { + LogicalPlan::TableScan(TableScan { + table_name, + source, + projection, + projected_schema, + filters, + fetch, + }) + }), LogicalPlan::Distinct(Distinct::On(DistinctOn { on_expr, select_expr, sort_expr, input, schema, - })) => map_until_stop_and_collect!( - on_expr.into_iter().map_until_stop_and_collect(&mut f), - select_expr, - select_expr.into_iter().map_until_stop_and_collect(&mut f), - sort_expr, - transform_sort_option_vec(sort_expr, &mut f) - )? - .update_data(|(on_expr, select_expr, sort_expr)| { - LogicalPlan::Distinct(Distinct::On(DistinctOn { - on_expr, - select_expr, - sort_expr, - input, - schema, - })) - }), - LogicalPlan::Limit(Limit { skip, fetch, input }) => { - let skip = skip.map(|e| *e); - let fetch = fetch.map(|e| *e); - map_until_stop_and_collect!( - skip.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { - Ok(f(e)?.update_data(Some)) - }), - fetch, - fetch.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { - Ok(f(e)?.update_data(Some)) - }) - )? - .update_data(|(skip, fetch)| { - LogicalPlan::Limit(Limit { - skip: skip.map(Box::new), - fetch: fetch.map(Box::new), + })) => (on_expr, select_expr, sort_expr) + .map_elements(f)? + .update_data(|(on_expr, select_expr, sort_expr)| { + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, input, - }) + schema, + })) + }), + LogicalPlan::Limit(Limit { skip, fetch, input }) => { + (skip, fetch).map_elements(f)?.update_data(|(skip, fetch)| { + LogicalPlan::Limit(Limit { skip, fetch, input }) }) } - LogicalPlan::Statement(stmt) => { - stmt.map_expressions(f)?.update_data(LogicalPlan::Statement) + LogicalPlan::Statement(stmt) => match stmt { + Statement::Execute(e) => { + e.parameters.map_elements(f)?.update_data(|parameters| { + Statement::Execute(Execute { parameters, ..e }) + }) + } + _ => Transformed::no(stmt), } + .update_data(LogicalPlan::Statement), // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::Unnest(_) diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index e964091aae668..eacace5ed0461 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -19,14 +19,14 @@ use crate::expr::{ AggregateFunction, Alias, Between, BinaryExpr, Case, Cast, GroupingSet, InList, - InSubquery, Like, Placeholder, ScalarFunction, Sort, TryCast, Unnest, WindowFunction, + InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, }; use crate::{Expr, ExprFunctionExt}; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, + Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer, }; -use datafusion_common::{map_until_stop_and_collect, Result}; +use datafusion_common::Result; /// Implementation of the [`TreeNode`] trait /// @@ -42,9 +42,9 @@ impl TreeNode for Expr { &'n self, f: F, ) -> Result { - let children = match self { - Expr::Alias(Alias{expr,..}) - | Expr::Unnest(Unnest{expr}) + match self { + Expr::Alias(Alias { expr, .. }) + | Expr::Unnest(Unnest { expr }) | Expr::Not(expr) | Expr::IsNotNull(expr) | Expr::IsTrue(expr) @@ -57,78 +57,50 @@ impl TreeNode for Expr { | Expr::Negative(expr) | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) - | Expr::InSubquery(InSubquery{ expr, .. }) => vec![expr.as_ref()], + | Expr::InSubquery(InSubquery { expr, .. }) => expr.apply_elements(f), Expr::GroupingSet(GroupingSet::Rollup(exprs)) - | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.iter().collect(), - Expr::ScalarFunction (ScalarFunction{ args, .. } ) => { - args.iter().collect() + | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.apply_elements(f), + Expr::ScalarFunction(ScalarFunction { args, .. }) => { + args.apply_elements(f) } Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { - lists_of_exprs.iter().flatten().collect() + lists_of_exprs.apply_elements(f) } Expr::Column(_) // Treat OuterReferenceColumn as a leaf expression | Expr::OuterReferenceColumn(_, _) | Expr::ScalarVariable(_, _) | Expr::Literal(_) - | Expr::Exists {..} + | Expr::Exists { .. } | Expr::ScalarSubquery(_) - | Expr::Wildcard {..} - | Expr::Placeholder (_) => vec![], + | Expr::Wildcard { .. } + | Expr::Placeholder(_) => Ok(TreeNodeRecursion::Continue), Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { - vec![left.as_ref(), right.as_ref()] + (left, right).apply_ref_elements(f) } Expr::Like(Like { expr, pattern, .. }) | Expr::SimilarTo(Like { expr, pattern, .. }) => { - vec![expr.as_ref(), pattern.as_ref()] + (expr, pattern).apply_ref_elements(f) } Expr::Between(Between { - expr, low, high, .. - }) => vec![expr.as_ref(), low.as_ref(), high.as_ref()], - Expr::Case(case) => { - let mut expr_vec = vec![]; - if let Some(expr) = case.expr.as_ref() { - expr_vec.push(expr.as_ref()); - }; - for (when, then) in case.when_then_expr.iter() { - expr_vec.push(when.as_ref()); - expr_vec.push(then.as_ref()); - } - if let Some(else_expr) = case.else_expr.as_ref() { - expr_vec.push(else_expr.as_ref()); - } - expr_vec - } - Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) - => { - let mut expr_vec = args.iter().collect::>(); - if let Some(f) = filter { - expr_vec.push(f.as_ref()); - } - if let Some(order_by) = order_by { - expr_vec.extend(order_by.iter().map(|sort| &sort.expr)); - } - expr_vec - } + expr, low, high, .. + }) => (expr, low, high).apply_ref_elements(f), + Expr::Case(Case { expr, when_then_expr, else_expr }) => + (expr, when_then_expr, else_expr).apply_ref_elements(f), + Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) => + (args, filter, order_by).apply_ref_elements(f), Expr::WindowFunction(WindowFunction { - args, - partition_by, - order_by, - .. - }) => { - let mut expr_vec = args.iter().collect::>(); - expr_vec.extend(partition_by); - expr_vec.extend(order_by.iter().map(|sort| &sort.expr)); - expr_vec + args, + partition_by, + order_by, + .. + }) => { + (args, partition_by, order_by).apply_ref_elements(f) } Expr::InList(InList { expr, list, .. }) => { - let mut expr_vec = vec![expr.as_ref()]; - expr_vec.extend(list); - expr_vec + (expr, list).apply_ref_elements(f) } - }; - - children.into_iter().apply_until_stop(f) + } } /// Maps each child of `self` using the provided closure `f`. @@ -148,137 +120,103 @@ impl TreeNode for Expr { | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) | Expr::Literal(_) => Transformed::no(self), - Expr::Unnest(Unnest { expr, .. }) => transform_box(expr, &mut f)? - .update_data(|be| Expr::Unnest(Unnest::new_boxed(be))), + Expr::Unnest(Unnest { expr, .. }) => expr + .map_elements(f)? + .update_data(|expr| Expr::Unnest(Unnest { expr })), Expr::Alias(Alias { expr, relation, name, - }) => f(*expr)?.update_data(|e| Expr::Alias(Alias::new(e, relation, name))), + }) => f(*expr)?.update_data(|e| e.alias_qualified(relation, name)), Expr::InSubquery(InSubquery { expr, subquery, negated, - }) => transform_box(expr, &mut f)?.update_data(|be| { + }) => expr.map_elements(f)?.update_data(|be| { Expr::InSubquery(InSubquery::new(be, subquery, negated)) }), - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - map_until_stop_and_collect!( - transform_box(left, &mut f), - right, - transform_box(right, &mut f) - )? + Expr::BinaryExpr(BinaryExpr { left, op, right }) => (left, right) + .map_elements(f)? .update_data(|(new_left, new_right)| { Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right)) - }) - } + }), Expr::Like(Like { negated, expr, pattern, escape_char, case_insensitive, - }) => map_until_stop_and_collect!( - transform_box(expr, &mut f), - pattern, - transform_box(pattern, &mut f) - )? - .update_data(|(new_expr, new_pattern)| { - Expr::Like(Like::new( - negated, - new_expr, - new_pattern, - escape_char, - case_insensitive, - )) - }), + }) => { + (expr, pattern) + .map_elements(f)? + .update_data(|(new_expr, new_pattern)| { + Expr::Like(Like::new( + negated, + new_expr, + new_pattern, + escape_char, + case_insensitive, + )) + }) + } Expr::SimilarTo(Like { negated, expr, pattern, escape_char, case_insensitive, - }) => map_until_stop_and_collect!( - transform_box(expr, &mut f), - pattern, - transform_box(pattern, &mut f) - )? - .update_data(|(new_expr, new_pattern)| { - Expr::SimilarTo(Like::new( - negated, - new_expr, - new_pattern, - escape_char, - case_insensitive, - )) - }), - Expr::Not(expr) => transform_box(expr, &mut f)?.update_data(Expr::Not), - Expr::IsNotNull(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsNotNull) - } - Expr::IsNull(expr) => transform_box(expr, &mut f)?.update_data(Expr::IsNull), - Expr::IsTrue(expr) => transform_box(expr, &mut f)?.update_data(Expr::IsTrue), - Expr::IsFalse(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsFalse) - } - Expr::IsUnknown(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsUnknown) - } - Expr::IsNotTrue(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsNotTrue) - } - Expr::IsNotFalse(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsNotFalse) + }) => { + (expr, pattern) + .map_elements(f)? + .update_data(|(new_expr, new_pattern)| { + Expr::SimilarTo(Like::new( + negated, + new_expr, + new_pattern, + escape_char, + case_insensitive, + )) + }) } + Expr::Not(expr) => expr.map_elements(f)?.update_data(Expr::Not), + Expr::IsNotNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNotNull), + Expr::IsNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNull), + Expr::IsTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsTrue), + Expr::IsFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsFalse), + Expr::IsUnknown(expr) => expr.map_elements(f)?.update_data(Expr::IsUnknown), + Expr::IsNotTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsNotTrue), + Expr::IsNotFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsNotFalse), Expr::IsNotUnknown(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsNotUnknown) - } - Expr::Negative(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::Negative) + expr.map_elements(f)?.update_data(Expr::IsNotUnknown) } + Expr::Negative(expr) => expr.map_elements(f)?.update_data(Expr::Negative), Expr::Between(Between { expr, negated, low, high, - }) => map_until_stop_and_collect!( - transform_box(expr, &mut f), - low, - transform_box(low, &mut f), - high, - transform_box(high, &mut f) - )? - .update_data(|(new_expr, new_low, new_high)| { - Expr::Between(Between::new(new_expr, negated, new_low, new_high)) - }), + }) => (expr, low, high).map_elements(f)?.update_data( + |(new_expr, new_low, new_high)| { + Expr::Between(Between::new(new_expr, negated, new_low, new_high)) + }, + ), Expr::Case(Case { expr, when_then_expr, else_expr, - }) => map_until_stop_and_collect!( - transform_option_box(expr, &mut f), - when_then_expr, - when_then_expr - .into_iter() - .map_until_stop_and_collect(|(when, then)| { - map_until_stop_and_collect!( - transform_box(when, &mut f), - then, - transform_box(then, &mut f) - ) - }), - else_expr, - transform_option_box(else_expr, &mut f) - )? - .update_data(|(new_expr, new_when_then_expr, new_else_expr)| { - Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr)) - }), - Expr::Cast(Cast { expr, data_type }) => transform_box(expr, &mut f)? + }) => (expr, when_then_expr, else_expr) + .map_elements(f)? + .update_data(|(new_expr, new_when_then_expr, new_else_expr)| { + Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr)) + }), + Expr::Cast(Cast { expr, data_type }) => expr + .map_elements(f)? .update_data(|be| Expr::Cast(Cast::new(be, data_type))), - Expr::TryCast(TryCast { expr, data_type }) => transform_box(expr, &mut f)? + Expr::TryCast(TryCast { expr, data_type }) => expr + .map_elements(f)? .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))), Expr::ScalarFunction(ScalarFunction { func, args }) => { - transform_vec(args, &mut f)?.map_data(|new_args| { + args.map_elements(f)?.map_data(|new_args| { Ok(Expr::ScalarFunction(ScalarFunction::new_udf( func, new_args, ))) @@ -291,22 +229,17 @@ impl TreeNode for Expr { order_by, window_frame, null_treatment, - }) => map_until_stop_and_collect!( - transform_vec(args, &mut f), - partition_by, - transform_vec(partition_by, &mut f), - order_by, - transform_sort_vec(order_by, &mut f) - )? - .update_data(|(new_args, new_partition_by, new_order_by)| { - Expr::WindowFunction(WindowFunction::new(fun, new_args)) - .partition_by(new_partition_by) - .order_by(new_order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build() - .unwrap() - }), + }) => (args, partition_by, order_by).map_elements(f)?.update_data( + |(new_args, new_partition_by, new_order_by)| { + Expr::WindowFunction(WindowFunction::new(fun, new_args)) + .partition_by(new_partition_by) + .order_by(new_order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build() + .unwrap() + }, + ), Expr::AggregateFunction(AggregateFunction { args, func, @@ -314,31 +247,27 @@ impl TreeNode for Expr { filter, order_by, null_treatment, - }) => map_until_stop_and_collect!( - transform_vec(args, &mut f), - filter, - transform_option_box(filter, &mut f), - order_by, - transform_sort_option_vec(order_by, &mut f) - )? - .map_data(|(new_args, new_filter, new_order_by)| { - Ok(Expr::AggregateFunction(AggregateFunction::new_udf( - func, - new_args, - distinct, - new_filter, - new_order_by, - null_treatment, - ))) - })?, + }) => (args, filter, order_by).map_elements(f)?.map_data( + |(new_args, new_filter, new_order_by)| { + Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + func, + new_args, + distinct, + new_filter, + new_order_by, + null_treatment, + ))) + }, + )?, Expr::GroupingSet(grouping_set) => match grouping_set { - GroupingSet::Rollup(exprs) => transform_vec(exprs, &mut f)? + GroupingSet::Rollup(exprs) => exprs + .map_elements(f)? .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))), - GroupingSet::Cube(exprs) => transform_vec(exprs, &mut f)? + GroupingSet::Cube(exprs) => exprs + .map_elements(f)? .update_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))), GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs - .into_iter() - .map_until_stop_and_collect(|exprs| transform_vec(exprs, &mut f))? + .map_elements(f)? .update_data(|new_lists_of_exprs| { Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs)) }), @@ -347,70 +276,11 @@ impl TreeNode for Expr { expr, list, negated, - }) => map_until_stop_and_collect!( - transform_box(expr, &mut f), - list, - transform_vec(list, &mut f) - )? - .update_data(|(new_expr, new_list)| { - Expr::InList(InList::new(new_expr, new_list, negated)) - }), + }) => (expr, list) + .map_elements(f)? + .update_data(|(new_expr, new_list)| { + Expr::InList(InList::new(new_expr, new_list, negated)) + }), }) } } - -/// Transforms a boxed expression by applying the provided closure `f`. -fn transform_box Result>>( - be: Box, - f: &mut F, -) -> Result>> { - Ok(f(*be)?.update_data(Box::new)) -} - -/// Transforms an optional boxed expression by applying the provided closure `f`. -fn transform_option_box Result>>( - obe: Option>, - f: &mut F, -) -> Result>>> { - obe.map_or(Ok(Transformed::no(None)), |be| { - Ok(transform_box(be, f)?.update_data(Some)) - }) -} - -/// &mut transform a Option<`Vec` of `Expr`s> -pub fn transform_option_vec Result>>( - ove: Option>, - f: &mut F, -) -> Result>>> { - ove.map_or(Ok(Transformed::no(None)), |ve| { - Ok(transform_vec(ve, f)?.update_data(Some)) - }) -} - -/// &mut transform a `Vec` of `Expr`s -fn transform_vec Result>>( - ve: Vec, - f: &mut F, -) -> Result>> { - ve.into_iter().map_until_stop_and_collect(f) -} - -/// Transforms an optional vector of sort expressions by applying the provided closure `f`. -pub fn transform_sort_option_vec Result>>( - sorts_option: Option>, - f: &mut F, -) -> Result>>> { - sorts_option.map_or(Ok(Transformed::no(None)), |sorts| { - Ok(transform_sort_vec(sorts, f)?.update_data(Some)) - }) -} - -/// Transforms an vector of sort expressions by applying the provided closure `f`. -pub fn transform_sort_vec Result>>( - sorts: Vec, - f: &mut F, -) -> Result>> { - sorts.into_iter().map_until_stop_and_collect(|s| { - Ok(f(s.expr)?.update_data(|e| Sort { expr: e, ..s })) - }) -} diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index b659e477f67e2..1519c54dbf68a 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -39,7 +39,7 @@ use datafusion_expr::{ use crate::optimize_projections::required_indices::RequiredIndicies; use crate::utils::NamePreserver; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, + Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, }; /// Optimizer rule to prune unnecessary columns from intermediate schemas @@ -484,7 +484,7 @@ fn merge_consecutive_projections(proj: Projection) -> Result Result /// Rewrite sort expressions that have a UNION plan as their input to remove the table reference. fn rewrite_sort_expr_for_union(exprs: Vec) -> Result> { - let sort_exprs = transform_sort_vec(exprs, &mut |expr| { - expr.transform_up(|expr| { - if let Expr::Column(mut col) = expr { - col.relation = None; - Ok(Transformed::yes(Expr::Column(col))) - } else { - Ok(Transformed::no(expr)) - } + let sort_exprs = exprs + .map_elements(&mut |expr: Expr| { + expr.transform_up(|expr| { + if let Expr::Column(mut col) = expr { + col.relation = None; + Ok(Transformed::yes(Expr::Column(col))) + } else { + Ok(Transformed::no(expr)) + } + }) }) - }) - .data()?; + .data()?; Ok(sort_exprs) } From 963e8af048811d57657584cbbdb460942ea426de Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Wed, 20 Nov 2024 16:27:25 +0800 Subject: [PATCH 35/45] Chunk based iteration in `accumulate_indices` (#13451) * filter chunk Signed-off-by: Jay Zhan * fmt Signed-off-by: Jay Zhan * acc Signed-off-by: Jay Zhan * BitIndexIterator Signed-off-by: Jay Zhan * cleanup Signed-off-by: Jay Zhan * count group Signed-off-by: Jay Zhan * add benches Signed-off-by: jayzhan211 * revert to fixed chunk based method instead of iterating set_indices Signed-off-by: jayzhan211 * revert count change Signed-off-by: jayzhan211 * clippy Signed-off-by: jayzhan211 * taplo format Signed-off-by: jayzhan211 --------- Signed-off-by: Jay Zhan Signed-off-by: jayzhan211 --- .../functions-aggregate-common/Cargo.toml | 7 ++ .../benches/accumulate.rs | 115 ++++++++++++++++++ .../groups_accumulator/accumulate.rs | 90 ++++++++++---- 3 files changed, 190 insertions(+), 22 deletions(-) create mode 100644 datafusion/functions-aggregate-common/benches/accumulate.rs diff --git a/datafusion/functions-aggregate-common/Cargo.toml b/datafusion/functions-aggregate-common/Cargo.toml index 9b299c1a11d7d..664746808fb48 100644 --- a/datafusion/functions-aggregate-common/Cargo.toml +++ b/datafusion/functions-aggregate-common/Cargo.toml @@ -43,3 +43,10 @@ datafusion-common = { workspace = true } datafusion-expr-common = { workspace = true } datafusion-physical-expr-common = { workspace = true } rand = { workspace = true } + +[dev-dependencies] +criterion = "0.5" + +[[bench]] +harness = false +name = "accumulate" diff --git a/datafusion/functions-aggregate-common/benches/accumulate.rs b/datafusion/functions-aggregate-common/benches/accumulate.rs new file mode 100644 index 0000000000000..f422f8a2a7bfd --- /dev/null +++ b/datafusion/functions-aggregate-common/benches/accumulate.rs @@ -0,0 +1,115 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use std::sync::Arc; + +use arrow::array::{ArrayRef, BooleanArray, Int64Array}; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices; + +fn generate_group_indices(len: usize) -> Vec { + (0..len).collect() +} + +fn generate_values(len: usize, has_null: bool) -> ArrayRef { + if has_null { + let values = (0..len) + .map(|i| if i % 7 == 0 { None } else { Some(i as i64) }) + .collect::>(); + Arc::new(Int64Array::from(values)) + } else { + let values = (0..len).map(|i| Some(i as i64)).collect::>(); + Arc::new(Int64Array::from(values)) + } +} + +fn generate_filter(len: usize) -> Option { + let values = (0..len) + .map(|i| { + if i % 7 == 0 { + None + } else if i % 5 == 0 { + Some(false) + } else { + Some(true) + } + }) + .collect::>(); + Some(BooleanArray::from(values)) +} + +fn criterion_benchmark(c: &mut Criterion) { + let len = 500_000; + let group_indices = generate_group_indices(len); + let rows_count = group_indices.len(); + let values = generate_values(len, true); + let opt_filter = generate_filter(len); + let mut counts: Vec = vec![0; rows_count]; + accumulate_indices( + &group_indices, + values.logical_nulls().as_ref(), + opt_filter.as_ref(), + |group_index| { + counts[group_index] += 1; + }, + ); + + c.bench_function("Handle both nulls and filter", |b| { + b.iter(|| { + accumulate_indices( + &group_indices, + values.logical_nulls().as_ref(), + opt_filter.as_ref(), + |group_index| { + counts[group_index] += 1; + }, + ); + }) + }); + + c.bench_function("Handle nulls only", |b| { + b.iter(|| { + accumulate_indices( + &group_indices, + values.logical_nulls().as_ref(), + None, + |group_index| { + counts[group_index] += 1; + }, + ); + }) + }); + + let values = generate_values(len, false); + c.bench_function("Handle filter only", |b| { + b.iter(|| { + accumulate_indices( + &group_indices, + values.logical_nulls().as_ref(), + opt_filter.as_ref(), + |group_index| { + counts[group_index] += 1; + }, + ); + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index 3efd348937ed4..ac4d0e75535e4 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -395,19 +395,41 @@ pub fn accumulate_indices( } } (None, Some(filter)) => { - assert_eq!(filter.len(), group_indices.len()); - // The performance with a filter could be improved by - // iterating over the filter in chunks, rather than a single - // iterator. TODO file a ticket - let iter = group_indices.iter().zip(filter.iter()); - for (&group_index, filter_value) in iter { - if let Some(true) = filter_value { - index_fn(group_index) - } - } + debug_assert_eq!(filter.len(), group_indices.len()); + let group_indices_chunks = group_indices.chunks_exact(64); + let bit_chunks = filter.values().bit_chunks(); + + let group_indices_remainder = group_indices_chunks.remainder(); + + group_indices_chunks.zip(bit_chunks.iter()).for_each( + |(group_index_chunk, mask)| { + // index_mask has value 1 << i in the loop + let mut index_mask = 1; + group_index_chunk.iter().for_each(|&group_index| { + // valid bit was set, real vale + let is_valid = (mask & index_mask) != 0; + if is_valid { + index_fn(group_index); + } + index_mask <<= 1; + }) + }, + ); + + // handle any remaining bits (after the initial 64) + let remainder_bits = bit_chunks.remainder_bits(); + group_indices_remainder + .iter() + .enumerate() + .for_each(|(i, &group_index)| { + let is_valid = remainder_bits & (1 << i) != 0; + if is_valid { + index_fn(group_index) + } + }); } (Some(valids), None) => { - assert_eq!(valids.len(), group_indices.len()); + debug_assert_eq!(valids.len(), group_indices.len()); // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum // iterate over in chunks of 64 bits for more efficient null checking let group_indices_chunks = group_indices.chunks_exact(64); @@ -444,20 +466,44 @@ pub fn accumulate_indices( } (Some(valids), Some(filter)) => { - assert_eq!(filter.len(), group_indices.len()); - assert_eq!(valids.len(), group_indices.len()); - // The performance with a filter could likely be improved by - // iterating over the filter in chunks, rather than using - // iterators. TODO file a ticket - filter + debug_assert_eq!(filter.len(), group_indices.len()); + debug_assert_eq!(valids.len(), group_indices.len()); + + let group_indices_chunks = group_indices.chunks_exact(64); + let valid_bit_chunks = valids.inner().bit_chunks(); + let filter_bit_chunks = filter.values().bit_chunks(); + + let group_indices_remainder = group_indices_chunks.remainder(); + + group_indices_chunks + .zip(valid_bit_chunks.iter()) + .zip(filter_bit_chunks.iter()) + .for_each(|((group_index_chunk, valid_mask), filter_mask)| { + // index_mask has value 1 << i in the loop + let mut index_mask = 1; + group_index_chunk.iter().for_each(|&group_index| { + // valid bit was set, real vale + let is_valid = (valid_mask & filter_mask & index_mask) != 0; + if is_valid { + index_fn(group_index); + } + index_mask <<= 1; + }) + }); + + // handle any remaining bits (after the initial 64) + let remainder_valid_bits = valid_bit_chunks.remainder_bits(); + let remainder_filter_bits = filter_bit_chunks.remainder_bits(); + group_indices_remainder .iter() - .zip(group_indices.iter()) - .zip(valids.iter()) - .for_each(|((filter_value, &group_index), is_valid)| { - if let (Some(true), true) = (filter_value, is_valid) { + .enumerate() + .for_each(|(i, &group_index)| { + let is_valid = + remainder_valid_bits & remainder_filter_bits & (1 << i) != 0; + if is_valid { index_fn(group_index) } - }) + }); } } } From 8ce4da637cc7564f2d03369c13c788e106858f27 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 20 Nov 2024 06:58:14 -0500 Subject: [PATCH 36/45] Clarify documentation about use of tokio tasks (#13474) --- datafusion/core/src/lib.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index b58ef66d4cd2b..d049e774d7c6d 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -475,10 +475,12 @@ //! //! The number of cores used is determined by the `target_partitions` //! configuration setting, which defaults to the number of CPU cores. -//! During execution, DataFusion creates this many distinct `async` [`Stream`]s and -//! this many distinct [Tokio] [`task`]s, which drive the `Stream`s -//! using threads managed by the `Runtime`. Many DataFusion `Stream`s perform -//! CPU intensive processing. +//! While preparing for execution, DataFusion tries to create this many distinct +//! `async` [`Stream`]s for each `ExecutionPlan`. +//! The `Stream`s for certain `ExecutionPlans`, such as as [`RepartitionExec`] +//! and [`CoalescePartitionsExec`], spawn [Tokio] [`task`]s, that are run by +//! threads managed by the `Runtime`. +//! Many DataFusion `Stream`s perform CPU intensive processing. //! //! Using `async` for CPU intensive tasks makes it easy for [`TableProvider`]s //! to perform network I/O using standard Rust `async` during execution. @@ -582,6 +584,8 @@ //! [`Runtime`]: tokio::runtime::Runtime //! [`task`]: tokio::task //! [Using Rustlang’s Async Tokio Runtime for CPU-Bound Tasks]: https://thenewstack.io/using-rustlangs-async-tokio-runtime-for-cpu-bound-tasks/ +//! [`RepartitionExec`]: physical_plan::repartition::RepartitionExec +//! [`CoalescePartitionsExec`]: physical_plan::coalesce_partitions::CoalescePartitionsExec //! //! ## State Management and Configuration //! From ecc04d4af85a29111a1598e615350fea84e60fcb Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Wed, 20 Nov 2024 06:59:21 -0500 Subject: [PATCH 37/45] feat: Support faster multi-column grouping ( `GroupColumn`) for `Date/Time/Timestamp` types (#13457) * feat: Add `GroupColumn` for `Date/Time/Timestamp` * Add tests --- .../src/aggregates/group_values/mod.rs | 28 +++ .../group_values/multi_group_by/mod.rs | 42 +++- .../sqllogictest/test_files/group_by.slt | 196 ++++++++++++++++++ 3 files changed, 263 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index a816203b68124..ae528daad53c5 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -18,7 +18,13 @@ //! [`GroupValues`] trait for storing and interning group keys use arrow::record_batch::RecordBatch; +use arrow_array::types::{ + Date32Type, Date64Type, Time32MillisecondType, Time32SecondType, + Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, +}; use arrow_array::{downcast_primitive, ArrayRef}; +use arrow_schema::TimeUnit; use arrow_schema::{DataType, SchemaRef}; use datafusion_common::Result; @@ -142,6 +148,28 @@ pub(crate) fn new_group_values( } match d { + DataType::Date32 => { + downcast_helper!(Date32Type, d); + } + DataType::Date64 => { + downcast_helper!(Date64Type, d); + } + DataType::Time32(t) => match t { + TimeUnit::Second => downcast_helper!(Time32SecondType, d), + TimeUnit::Millisecond => downcast_helper!(Time32MillisecondType, d), + _ => {} + }, + DataType::Time64(t) => match t { + TimeUnit::Microsecond => downcast_helper!(Time64MicrosecondType, d), + TimeUnit::Nanosecond => downcast_helper!(Time64NanosecondType, d), + _ => {} + }, + DataType::Timestamp(t, _) => match t { + TimeUnit::Second => downcast_helper!(TimestampSecondType, d), + TimeUnit::Millisecond => downcast_helper!(TimestampMillisecondType, d), + TimeUnit::Microsecond => downcast_helper!(TimestampMicrosecondType, d), + TimeUnit::Nanosecond => downcast_helper!(TimestampNanosecondType, d), + }, DataType::Utf8 => { return Ok(Box::new(GroupValuesByes::::new(OutputType::Utf8))); } diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs index 83b0f9d773693..10b00cf74fdb7 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -32,12 +32,14 @@ use ahash::RandomState; use arrow::compute::cast; use arrow::datatypes::{ BinaryViewType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, - Int32Type, Int64Type, Int8Type, StringViewType, UInt16Type, UInt32Type, UInt64Type, - UInt8Type, + Int32Type, Int64Type, Int8Type, StringViewType, Time32MillisecondType, + Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; use arrow::record_batch::RecordBatch; use arrow_array::{Array, ArrayRef}; -use arrow_schema::{DataType, Schema, SchemaRef}; +use arrow_schema::{DataType, Schema, SchemaRef, TimeUnit}; use datafusion_common::hash_utils::create_hashes; use datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; @@ -913,6 +915,38 @@ impl GroupValues for GroupValuesColumn { } &DataType::Date32 => instantiate_primitive!(v, nullable, Date32Type), &DataType::Date64 => instantiate_primitive!(v, nullable, Date64Type), + &DataType::Time32(t) => match t { + TimeUnit::Second => { + instantiate_primitive!(v, nullable, Time32SecondType) + } + TimeUnit::Millisecond => { + instantiate_primitive!(v, nullable, Time32MillisecondType) + } + _ => {} + }, + &DataType::Time64(t) => match t { + TimeUnit::Microsecond => { + instantiate_primitive!(v, nullable, Time64MicrosecondType) + } + TimeUnit::Nanosecond => { + instantiate_primitive!(v, nullable, Time64NanosecondType) + } + _ => {} + }, + &DataType::Timestamp(t, _) => match t { + TimeUnit::Second => { + instantiate_primitive!(v, nullable, TimestampSecondType) + } + TimeUnit::Millisecond => { + instantiate_primitive!(v, nullable, TimestampMillisecondType) + } + TimeUnit::Microsecond => { + instantiate_primitive!(v, nullable, TimestampMicrosecondType) + } + TimeUnit::Nanosecond => { + instantiate_primitive!(v, nullable, TimestampNanosecondType) + } + }, &DataType::Utf8 => { let b = ByteGroupValueBuilder::::new(OutputType::Utf8); v.push(Box::new(b) as _) @@ -1125,6 +1159,8 @@ fn supported_type(data_type: &DataType) -> bool { | DataType::LargeBinary | DataType::Date32 | DataType::Date64 + | DataType::Time32(_) + | DataType::Timestamp(_, _) | DataType::Utf8View | DataType::BinaryView ) diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 391f84836871c..f74e1006f7f67 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -5272,6 +5272,201 @@ drop view t statement ok drop table source; +# Test multi group by int + Date32 +statement ok +create table source as values +(1, '2020-01-01'), +(1, '2020-01-01'), +(2, '2020-01-02'), +(2, '2020-01-03'), +(3, '2020-01-04'), +(3, '2020-01-04'), +(2, '2020-01-03'), +(null, null), +(null, '2020-01-01'), +(null, null), +(null, '2020-01-01'), +(2, '2020-01-02'), +(2, '2020-01-02'), +(1, null) +; + +statement ok +create view t as select column1 as a, arrow_cast(column2, 'Date32') as b from source; + +query IDI +select a, b, count(*) from t group by a, b order by a, b; +---- +1 2020-01-01 2 +1 NULL 1 +2 2020-01-02 3 +2 2020-01-03 2 +3 2020-01-04 2 +NULL 2020-01-01 2 +NULL NULL 2 + +statement ok +drop view t + +statement ok +drop table source; + +# Test multi group by int + Date64 +statement ok +create table source as values +(1, '2020-01-01'), +(1, '2020-01-01'), +(2, '2020-01-02'), +(2, '2020-01-03'), +(3, '2020-01-04'), +(3, '2020-01-04'), +(2, '2020-01-03'), +(null, null), +(null, '2020-01-01'), +(null, null), +(null, '2020-01-01'), +(2, '2020-01-02'), +(2, '2020-01-02'), +(1, null) +; + +statement ok +create view t as select column1 as a, arrow_cast(column2, 'Date64') as b from source; + +query IDI +select a, b, count(*) from t group by a, b order by a, b; +---- +1 2020-01-01T00:00:00 2 +1 NULL 1 +2 2020-01-02T00:00:00 3 +2 2020-01-03T00:00:00 2 +3 2020-01-04T00:00:00 2 +NULL 2020-01-01T00:00:00 2 +NULL NULL 2 + +statement ok +drop view t + +statement ok +drop table source; + +# Test multi group by int + Time32 +statement ok +create table source as values +(1, '12:34:56'), +(1, '12:34:56'), +(2, '13:00:00'), +(2, '14:15:00'), +(3, '23:59:59'), +(3, '23:59:59'), +(2, '14:15:00'), +(null, null), +(null, '12:00:00'), +(null, null), +(null, '12:00:00'), +(2, '13:00:00'), +(2, '13:00:00'), +(1, null) +; + +statement ok +create view t as select column1 as a, arrow_cast(column2, 'Time32(Second)') as b from source; + +query IDI +select a, b, count(*) from t group by a, b order by a, b; +---- +1 12:34:56 2 +1 NULL 1 +2 13:00:00 3 +2 14:15:00 2 +3 23:59:59 2 +NULL 12:00:00 2 +NULL NULL 2 + +statement ok +drop view t + +statement ok +drop table source; + +# Test multi group by int + Time64 +statement ok +create table source as values +(1, '12:34:56.123456'), +(1, '12:34:56.123456'), +(2, '13:00:00.000001'), +(2, '14:15:00.999999'), +(3, '23:59:59.500000'), +(3, '23:59:59.500000'), +(2, '14:15:00.999999'), +(null, null), +(null, '12:00:00.000000'), +(null, null), +(null, '12:00:00.000000'), +(2, '13:00:00.000001'), +(2, '13:00:00.000001'), +(1, null) +; + +statement ok +create view t as select column1 as a, arrow_cast(column2, 'Time64(Microsecond)') as b from source; + +query IDI +select a, b, count(*) from t group by a, b order by a, b; +---- +1 12:34:56.123456 2 +1 NULL 1 +2 13:00:00.000001 3 +2 14:15:00.999999 2 +3 23:59:59.500 2 +NULL 12:00:00 2 +NULL NULL 2 + +statement ok +drop view t + +statement ok +drop table source; + +# Test multi group by int + Timestamp +statement ok +create table source as values +(1, '2020-01-01 12:34:56'), +(1, '2020-01-01 12:34:56'), +(2, '2020-01-02 13:00:00'), +(2, '2020-01-03 14:15:00'), +(3, '2020-01-04 23:59:59'), +(3, '2020-01-04 23:59:59'), +(2, '2020-01-03 14:15:00'), +(null, null), +(null, '2020-01-01 12:00:00'), +(null, null), +(null, '2020-01-01 12:00:00'), +(2, '2020-01-02 13:00:00'), +(2, '2020-01-02 13:00:00'), +(1, null) +; + +statement ok +create view t as select column1 as a, arrow_cast(column2, 'Timestamp(Nanosecond, None)') as b from source; + +query IPI +select a, b, count(*) from t group by a, b order by a, b; +---- +1 2020-01-01T12:34:56 2 +1 NULL 1 +2 2020-01-02T13:00:00 3 +2 2020-01-03T14:15:00 2 +3 2020-01-04T23:59:59 2 +NULL 2020-01-01T12:00:00 2 +NULL NULL 2 + +statement ok +drop view t + +statement ok +drop table source; + # Test whether min, max accumulator produces NaN result when input is NaN. # See https://github.com/apache/datafusion/issues/13415 for rationale statement ok @@ -5287,3 +5482,4 @@ query RR SELECT max(input_table.x), min(input_table.x) from input_table GROUP BY input_table."row"; ---- NaN NaN + From a2f4878ff0b972038abeea335b0cf72547b01633 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Wed, 20 Nov 2024 17:56:38 +0100 Subject: [PATCH 38/45] Fix DataFusionError use in schema_err macro (#13488) Use declaring-crate-relative references so that macro use place does not need to import symbols it doesn't use. --- datafusion/common/src/column.rs | 2 +- datafusion/common/src/error.rs | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index d855198fa7c6b..c47ed28159064 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -21,7 +21,7 @@ use arrow_schema::{Field, FieldRef}; use crate::error::_schema_err; use crate::utils::{parse_identifiers_normalized, quote_identifier}; -use crate::{DFSchema, DataFusionError, Result, SchemaError, TableReference}; +use crate::{DFSchema, Result, SchemaError, TableReference}; use std::collections::HashSet; use std::convert::Infallible; use std::fmt; diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 05988d6c6da4c..4fac7298c455a 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -598,9 +598,9 @@ macro_rules! arrow_err { #[macro_export] macro_rules! schema_datafusion_err { ($ERR:expr) => { - DataFusionError::SchemaError( + $crate::error::DataFusionError::SchemaError( $ERR, - Box::new(Some(DataFusionError::get_back_trace())), + Box::new(Some($crate::error::DataFusionError::get_back_trace())), ) }; } @@ -609,9 +609,9 @@ macro_rules! schema_datafusion_err { #[macro_export] macro_rules! schema_err { ($ERR:expr) => { - Err(DataFusionError::SchemaError( + Err($crate::error::DataFusionError::SchemaError( $ERR, - Box::new(Some(DataFusionError::get_back_trace())), + Box::new(Some($crate::error::DataFusionError::get_back_trace())), )) }; } From 5ee524ec70b1f10a078caca62954ce37b2dc3cc6 Mon Sep 17 00:00:00 2001 From: Filippo Rossi Date: Wed, 20 Nov 2024 18:11:18 +0100 Subject: [PATCH 39/45] feat(substrait): replace SessionContext with a trait (#13343) * feat(substrait): replace SessionContext with SessionState * feat(substrait): add logical plan context * chore(substrait): add apache header * docs: fix code in docs * docs(substrait): rename and document context * chore(substrait): context -> state * chore: fmt --- .../core/src/execution/session_state.rs | 4 +- datafusion/substrait/Cargo.toml | 1 + datafusion/substrait/src/lib.rs | 4 +- .../substrait/src/logical_plan/consumer.rs | 286 ++++++++++-------- datafusion/substrait/src/logical_plan/mod.rs | 1 + .../substrait/src/logical_plan/producer.rs | 196 ++++++------ .../substrait/src/logical_plan/state.rs | 63 ++++ datafusion/substrait/src/serializer.rs | 2 +- .../tests/cases/consumer_integration.rs | 2 +- .../substrait/tests/cases/emit_kind_tests.rs | 12 +- .../substrait/tests/cases/function_test.rs | 2 +- .../substrait/tests/cases/logical_plans.rs | 6 +- .../tests/cases/roundtrip_logical_plan.rs | 40 +-- datafusion/substrait/tests/cases/serialize.rs | 12 +- .../tests/cases/substrait_validations.rs | 10 +- 15 files changed, 379 insertions(+), 262 deletions(-) create mode 100644 datafusion/substrait/src/logical_plan/state.rs diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 9fc081dd53298..e99cf82223815 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -296,7 +296,9 @@ impl SessionState { .resolve(&catalog.default_catalog, &catalog.default_schema) } - pub(crate) fn schema_for_ref( + /// Retrieve the [`SchemaProvider`] for a specific [`TableReference`], if it + /// esists. + pub fn schema_for_ref( &self, table_ref: impl Into, ) -> datafusion_common::Result> { diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index 192fe26d6cef6..61cdf3e91e3c1 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -34,6 +34,7 @@ workspace = true [dependencies] arrow-buffer = { workspace = true } async-recursion = "1.0" +async-trait = { workspace = true } chrono = { workspace = true } datafusion = { workspace = true, default-features = true } itertools = { workspace = true } diff --git a/datafusion/substrait/src/lib.rs b/datafusion/substrait/src/lib.rs index a6f7c033f9d0b..1389cac75b99c 100644 --- a/datafusion/substrait/src/lib.rs +++ b/datafusion/substrait/src/lib.rs @@ -64,10 +64,10 @@ //! let plan = df.into_optimized_plan()?; //! //! // Convert the plan into a substrait (protobuf) Plan -//! let substrait_plan = logical_plan::producer::to_substrait_plan(&plan, &ctx)?; +//! let substrait_plan = logical_plan::producer::to_substrait_plan(&plan, &ctx.state())?; //! //! // Receive a substrait protobuf from somewhere, and turn it into a LogicalPlan -//! let logical_round_trip = logical_plan::consumer::from_substrait_plan(&ctx, &substrait_plan).await?; +//! let logical_round_trip = logical_plan::consumer::from_substrait_plan(&ctx.state(), &substrait_plan).await?; //! let logical_round_trip = ctx.state().optimize(&logical_round_trip)?; //! assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); //! # Ok(()) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 1cce228527ecf..77e9eb81f546f 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -26,7 +26,7 @@ use datafusion::common::{ not_impl_err, plan_datafusion_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, }; -use datafusion::execution::FunctionRegistry; +use datafusion::datasource::provider_as_source; use datafusion::logical_expr::expr::{Exists, InSubquery, Sort}; use datafusion::logical_expr::{ @@ -56,7 +56,6 @@ use crate::variation_const::{ use datafusion::arrow::array::{new_empty_array, AsArray}; use datafusion::arrow::temporal_conversions::NANOSECONDS; use datafusion::common::scalar::ScalarStructBuilder; -use datafusion::dataframe::DataFrame; use datafusion::logical_expr::builder::project; use datafusion::logical_expr::expr::InList; use datafusion::logical_expr::{ @@ -66,9 +65,7 @@ use datafusion::logical_expr::{ use datafusion::prelude::JoinType; use datafusion::sql::TableReference; use datafusion::{ - error::Result, - logical_expr::utils::split_conjunction, - prelude::{Column, SessionContext}, + error::Result, logical_expr::utils::split_conjunction, prelude::Column, scalar::ScalarValue, }; use std::collections::HashSet; @@ -102,6 +99,8 @@ use substrait::proto::{ }; use substrait::proto::{ExtendedExpression, FunctionArgument, SortField}; +use super::state::SubstraitPlanningState; + // Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which // is the same as the expectation for any non-empty timezone in DF, so any non-empty timezone // results in correct points on the timeline, and we pick UTC as a reasonable default. @@ -203,15 +202,15 @@ fn split_eq_and_noneq_join_predicate_with_nulls_equality( async fn union_rels( rels: &[Rel], - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, extensions: &Extensions, is_all: bool, ) -> Result { let mut union_builder = Ok(LogicalPlanBuilder::from( - from_substrait_rel(ctx, &rels[0], extensions).await?, + from_substrait_rel(state, &rels[0], extensions).await?, )); for input in &rels[1..] { - let rel_plan = from_substrait_rel(ctx, input, extensions).await?; + let rel_plan = from_substrait_rel(state, input, extensions).await?; union_builder = if is_all { union_builder?.union(rel_plan) @@ -224,16 +223,16 @@ async fn union_rels( async fn intersect_rels( rels: &[Rel], - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, extensions: &Extensions, is_all: bool, ) -> Result { - let mut rel = from_substrait_rel(ctx, &rels[0], extensions).await?; + let mut rel = from_substrait_rel(state, &rels[0], extensions).await?; for input in &rels[1..] { rel = LogicalPlanBuilder::intersect( rel, - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(state, input, extensions).await?, is_all, )? } @@ -243,16 +242,16 @@ async fn intersect_rels( async fn except_rels( rels: &[Rel], - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, extensions: &Extensions, is_all: bool, ) -> Result { - let mut rel = from_substrait_rel(ctx, &rels[0], extensions).await?; + let mut rel = from_substrait_rel(state, &rels[0], extensions).await?; for input in &rels[1..] { rel = LogicalPlanBuilder::except( rel, - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(state, input, extensions).await?, is_all, )? } @@ -262,7 +261,7 @@ async fn except_rels( /// Convert Substrait Plan to DataFusion LogicalPlan pub async fn from_substrait_plan( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, plan: &Plan, ) -> Result { // Register function extension @@ -277,10 +276,10 @@ pub async fn from_substrait_plan( match plan.relations[0].rel_type.as_ref() { Some(rt) => match rt { plan_rel::RelType::Rel(rel) => { - Ok(from_substrait_rel(ctx, rel, &extensions).await?) + Ok(from_substrait_rel(state, rel, &extensions).await?) }, plan_rel::RelType::Root(root) => { - let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &extensions).await?; + let plan = from_substrait_rel(state, root.input.as_ref().unwrap(), &extensions).await?; if root.names.is_empty() { // Backwards compatibility for plans missing names return Ok(plan); @@ -341,7 +340,7 @@ pub struct ExprContainer { /// between systems. This is often useful for scenarios like pushdown where filter /// expressions need to be sent to remote systems. pub async fn from_substrait_extended_expr( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, extended_expr: &ExtendedExpression, ) -> Result { // Register function extension @@ -370,7 +369,7 @@ pub async fn from_substrait_extended_expr( } }?; let expr = - from_substrait_rex(ctx, scalar_expr, &input_schema, &extensions).await?; + from_substrait_rex(state, scalar_expr, &input_schema, &extensions).await?; let (output_type, expected_nullability) = expr.data_type_and_nullable(&input_schema)?; let output_field = Field::new("", output_type, expected_nullability); @@ -561,7 +560,7 @@ fn make_renamed_schema( #[allow(deprecated)] #[async_recursion] pub async fn from_substrait_rel( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, rel: &Rel, extensions: &Extensions, ) -> Result { @@ -569,7 +568,7 @@ pub async fn from_substrait_rel( Some(RelType::Project(p)) => { if let Some(input) = p.input.as_ref() { let mut input = LogicalPlanBuilder::from( - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(state, input, extensions).await?, ); let original_schema = input.schema().clone(); @@ -587,9 +586,13 @@ pub async fn from_substrait_rel( let mut explicit_exprs: Vec = vec![]; for expr in &p.expressions { - let e = - from_substrait_rex(ctx, expr, input.clone().schema(), extensions) - .await?; + let e = from_substrait_rex( + state, + expr, + input.clone().schema(), + extensions, + ) + .await?; // if the expression is WindowFunction, wrap in a Window relation if let Expr::WindowFunction(_) = &e { // Adding the same expression here and in the project below @@ -617,11 +620,11 @@ pub async fn from_substrait_rel( Some(RelType::Filter(filter)) => { if let Some(input) = filter.input.as_ref() { let input = LogicalPlanBuilder::from( - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(state, input, extensions).await?, ); if let Some(condition) = filter.condition.as_ref() { let expr = - from_substrait_rex(ctx, condition, input.schema(), extensions) + from_substrait_rex(state, condition, input.schema(), extensions) .await?; input.filter(expr)?.build() } else { @@ -634,7 +637,7 @@ pub async fn from_substrait_rel( Some(RelType::Fetch(fetch)) => { if let Some(input) = fetch.input.as_ref() { let input = LogicalPlanBuilder::from( - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(state, input, extensions).await?, ); let offset = fetch.offset as usize; // -1 means that ALL records should be returned @@ -651,10 +654,10 @@ pub async fn from_substrait_rel( Some(RelType::Sort(sort)) => { if let Some(input) = sort.input.as_ref() { let input = LogicalPlanBuilder::from( - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(state, input, extensions).await?, ); let sorts = - from_substrait_sorts(ctx, &sort.sorts, input.schema(), extensions) + from_substrait_sorts(state, &sort.sorts, input.schema(), extensions) .await?; input.sort(sorts)?.build() } else { @@ -664,13 +667,13 @@ pub async fn from_substrait_rel( Some(RelType::Aggregate(agg)) => { if let Some(input) = agg.input.as_ref() { let input = LogicalPlanBuilder::from( - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(state, input, extensions).await?, ); let mut ref_group_exprs = vec![]; for e in &agg.grouping_expressions { let x = - from_substrait_rex(ctx, e, input.schema(), extensions).await?; + from_substrait_rex(state, e, input.schema(), extensions).await?; ref_group_exprs.push(x); } @@ -681,7 +684,7 @@ pub async fn from_substrait_rel( 1 => { group_exprs.extend_from_slice( &from_substrait_grouping( - ctx, + state, &agg.groupings[0], &ref_group_exprs, input.schema(), @@ -694,7 +697,7 @@ pub async fn from_substrait_rel( let mut grouping_sets = vec![]; for grouping in &agg.groupings { let grouping_set = from_substrait_grouping( - ctx, + state, grouping, &ref_group_exprs, input.schema(), @@ -716,7 +719,7 @@ pub async fn from_substrait_rel( for m in &agg.measures { let filter = match &m.filter { Some(fil) => Some(Box::new( - from_substrait_rex(ctx, fil, input.schema(), extensions) + from_substrait_rex(state, fil, input.schema(), extensions) .await?, )), None => None, @@ -739,7 +742,7 @@ pub async fn from_substrait_rel( let order_by = if !f.sorts.is_empty() { Some( from_substrait_sorts( - ctx, + state, &f.sorts, input.schema(), extensions, @@ -751,7 +754,7 @@ pub async fn from_substrait_rel( }; from_substrait_agg_func( - ctx, + state, f, input.schema(), extensions, @@ -780,10 +783,12 @@ pub async fn from_substrait_rel( } let left: LogicalPlanBuilder = LogicalPlanBuilder::from( - from_substrait_rel(ctx, join.left.as_ref().unwrap(), extensions).await?, + from_substrait_rel(state, join.left.as_ref().unwrap(), extensions) + .await?, ); let right = LogicalPlanBuilder::from( - from_substrait_rel(ctx, join.right.as_ref().unwrap(), extensions).await?, + from_substrait_rel(state, join.right.as_ref().unwrap(), extensions) + .await?, ); let (left, right) = requalify_sides_if_needed(left, right)?; @@ -796,7 +801,7 @@ pub async fn from_substrait_rel( // Otherwise, build join with only the filter, without join keys match &join.expression.as_ref() { Some(expr) => { - let on = from_substrait_rex(ctx, expr, &in_join_schema, extensions) + let on = from_substrait_rex(state, expr, &in_join_schema, extensions) .await?; // The join expression can contain both equal and non-equal ops. // As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields. @@ -831,26 +836,44 @@ pub async fn from_substrait_rel( } Some(RelType::Cross(cross)) => { let left = LogicalPlanBuilder::from( - from_substrait_rel(ctx, cross.left.as_ref().unwrap(), extensions).await?, + from_substrait_rel(state, cross.left.as_ref().unwrap(), extensions) + .await?, ); let right = LogicalPlanBuilder::from( - from_substrait_rel(ctx, cross.right.as_ref().unwrap(), extensions) + from_substrait_rel(state, cross.right.as_ref().unwrap(), extensions) .await?, ); let (left, right) = requalify_sides_if_needed(left, right)?; left.cross_join(right.build()?)?.build() } Some(RelType::Read(read)) => { - fn read_with_schema( - df: DataFrame, + async fn read_with_schema( + state: &dyn SubstraitPlanningState, + table_ref: TableReference, schema: DFSchema, projection: &Option, ) -> Result { - ensure_schema_compatability(df.schema().to_owned(), schema.clone())?; + let schema = schema.replace_qualifier(table_ref.clone()); + + let plan = { + let provider = match state.table(&table_ref).await? { + Some(ref provider) => Arc::clone(provider), + _ => return plan_err!("No table named '{table_ref}'"), + }; + + LogicalPlanBuilder::scan( + table_ref, + provider_as_source(Arc::clone(&provider)), + None, + )? + .build()? + }; + + ensure_schema_compatability(plan.schema(), schema.clone())?; let schema = apply_masking(schema, projection)?; - apply_projection(df, schema) + apply_projection(plan, schema) } let named_struct = read.base_schema.as_ref().ok_or_else(|| { @@ -879,12 +902,13 @@ pub async fn from_substrait_rel( }, }; - let t = ctx.table(table_reference.clone()).await?; - - let substrait_schema = - substrait_schema.replace_qualifier(table_reference); - - read_with_schema(t, substrait_schema, &read.projection) + read_with_schema( + state, + table_reference, + substrait_schema, + &read.projection, + ) + .await } Some(ReadType::VirtualTable(vt)) => { if vt.values.is_empty() { @@ -960,12 +984,14 @@ pub async fn from_substrait_rel( let name = filename.unwrap(); // directly use unwrap here since we could determine it is a valid one let table_reference = TableReference::Bare { table: name.into() }; - let t = ctx.table(table_reference.clone()).await?; - - let substrait_schema = - substrait_schema.replace_qualifier(table_reference); - read_with_schema(t, substrait_schema, &read.projection) + read_with_schema( + state, + table_reference, + substrait_schema, + &read.projection, + ) + .await } _ => { not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type) @@ -979,31 +1005,31 @@ pub async fn from_substrait_rel( } else { match set_op { set_rel::SetOp::UnionAll => { - union_rels(&set.inputs, ctx, extensions, true).await + union_rels(&set.inputs, state, extensions, true).await } set_rel::SetOp::UnionDistinct => { - union_rels(&set.inputs, ctx, extensions, false).await + union_rels(&set.inputs, state, extensions, false).await } set_rel::SetOp::IntersectionPrimary => { LogicalPlanBuilder::intersect( - from_substrait_rel(ctx, &set.inputs[0], extensions) + from_substrait_rel(state, &set.inputs[0], extensions) .await?, - union_rels(&set.inputs[1..], ctx, extensions, true) + union_rels(&set.inputs[1..], state, extensions, true) .await?, false, ) } set_rel::SetOp::IntersectionMultiset => { - intersect_rels(&set.inputs, ctx, extensions, false).await + intersect_rels(&set.inputs, state, extensions, false).await } set_rel::SetOp::IntersectionMultisetAll => { - intersect_rels(&set.inputs, ctx, extensions, true).await + intersect_rels(&set.inputs, state, extensions, true).await } set_rel::SetOp::MinusPrimary => { - except_rels(&set.inputs, ctx, extensions, false).await + except_rels(&set.inputs, state, extensions, false).await } set_rel::SetOp::MinusPrimaryAll => { - except_rels(&set.inputs, ctx, extensions, true).await + except_rels(&set.inputs, state, extensions, true).await } _ => not_impl_err!("Unsupported set operator: {set_op:?}"), } @@ -1015,8 +1041,7 @@ pub async fn from_substrait_rel( let Some(ext_detail) = &extension.detail else { return substrait_err!("Unexpected empty detail in ExtensionLeafRel"); }; - let plan = ctx - .state() + let plan = state .serializer_registry() .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; Ok(LogicalPlan::Extension(Extension { node: plan })) @@ -1025,8 +1050,7 @@ pub async fn from_substrait_rel( let Some(ext_detail) = &extension.detail else { return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); }; - let plan = ctx - .state() + let plan = state .serializer_registry() .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; let Some(input_rel) = &extension.input else { @@ -1034,7 +1058,7 @@ pub async fn from_substrait_rel( "ExtensionSingleRel doesn't contains input rel. Try use ExtensionLeafRel instead" ); }; - let input_plan = from_substrait_rel(ctx, input_rel, extensions).await?; + let input_plan = from_substrait_rel(state, input_rel, extensions).await?; let plan = plan.with_exprs_and_inputs(plan.expressions(), vec![input_plan])?; Ok(LogicalPlan::Extension(Extension { node: plan })) @@ -1043,13 +1067,12 @@ pub async fn from_substrait_rel( let Some(ext_detail) = &extension.detail else { return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); }; - let plan = ctx - .state() + let plan = state .serializer_registry() .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; let mut inputs = Vec::with_capacity(extension.inputs.len()); for input in &extension.inputs { - let input_plan = from_substrait_rel(ctx, input, extensions).await?; + let input_plan = from_substrait_rel(state, input, extensions).await?; inputs.push(input_plan); } let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?; @@ -1059,7 +1082,7 @@ pub async fn from_substrait_rel( let Some(input) = exchange.input.as_ref() else { return substrait_err!("Unexpected empty input in ExchangeRel"); }; - let input = Arc::new(from_substrait_rel(ctx, input, extensions).await?); + let input = Arc::new(from_substrait_rel(state, input, extensions).await?); let Some(exchange_kind) = &exchange.exchange_kind else { return substrait_err!("Unexpected empty input in ExchangeRel"); @@ -1237,7 +1260,7 @@ impl NameTracker { /// DataFusion schema may have MORE fields, but not the other way around. /// 2. All fields are compatible. See [`ensure_field_compatability`] for details fn ensure_schema_compatability( - table_schema: DFSchema, + table_schema: &DFSchema, substrait_schema: DFSchema, ) -> Result<()> { substrait_schema @@ -1253,16 +1276,19 @@ fn ensure_schema_compatability( /// This function returns a DataFrame with fields adjusted if necessary in the event that the /// Substrait schema is a subset of the DataFusion schema. -fn apply_projection(table: DataFrame, substrait_schema: DFSchema) -> Result { - let df_schema = table.schema().to_owned(); - - let t = table.into_unoptimized_plan(); +fn apply_projection( + plan: LogicalPlan, + substrait_schema: DFSchema, +) -> Result { + let df_schema = plan.schema(); if df_schema.logically_equivalent_names_and_types(&substrait_schema) { - return Ok(t); + return Ok(plan); } - match t { + let df_schema = df_schema.to_owned(); + + match plan { LogicalPlan::TableScan(mut scan) => { let column_indices: Vec = substrait_schema .strip_qualifiers() @@ -1389,7 +1415,7 @@ fn from_substrait_jointype(join_type: i32) -> Result { /// Convert Substrait Sorts to DataFusion Exprs pub async fn from_substrait_sorts( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, substrait_sorts: &Vec, input_schema: &DFSchema, extensions: &Extensions, @@ -1397,7 +1423,7 @@ pub async fn from_substrait_sorts( let mut sorts: Vec = vec![]; for s in substrait_sorts { let expr = - from_substrait_rex(ctx, s.expr.as_ref().unwrap(), input_schema, extensions) + from_substrait_rex(state, s.expr.as_ref().unwrap(), input_schema, extensions) .await?; let asc_nullfirst = match &s.sort_kind { Some(k) => match k { @@ -1439,14 +1465,15 @@ pub async fn from_substrait_sorts( /// Convert Substrait Expressions to DataFusion Exprs pub async fn from_substrait_rex_vec( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, exprs: &Vec, input_schema: &DFSchema, extensions: &Extensions, ) -> Result> { let mut expressions: Vec = vec![]; for expr in exprs { - let expression = from_substrait_rex(ctx, expr, input_schema, extensions).await?; + let expression = + from_substrait_rex(state, expr, input_schema, extensions).await?; expressions.push(expression); } Ok(expressions) @@ -1454,7 +1481,7 @@ pub async fn from_substrait_rex_vec( /// Convert Substrait FunctionArguments to DataFusion Exprs pub async fn from_substrait_func_args( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, arguments: &Vec, input_schema: &DFSchema, extensions: &Extensions, @@ -1463,7 +1490,7 @@ pub async fn from_substrait_func_args( for arg in arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => { - from_substrait_rex(ctx, e, input_schema, extensions).await + from_substrait_rex(state, e, input_schema, extensions).await } _ => not_impl_err!("Function argument non-Value type not supported"), }; @@ -1474,7 +1501,7 @@ pub async fn from_substrait_func_args( /// Convert Substrait AggregateFunction to DataFusion Expr pub async fn from_substrait_agg_func( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, f: &AggregateFunction, input_schema: &DFSchema, extensions: &Extensions, @@ -1483,7 +1510,7 @@ pub async fn from_substrait_agg_func( distinct: bool, ) -> Result> { let args = - from_substrait_func_args(ctx, &f.arguments, input_schema, extensions).await?; + from_substrait_func_args(state, &f.arguments, input_schema, extensions).await?; let Some(function_name) = extensions.functions.get(&f.function_reference) else { return plan_err!( @@ -1494,7 +1521,7 @@ pub async fn from_substrait_agg_func( let function_name = substrait_fun_name(function_name); // try udaf first, then built-in aggr fn. - if let Ok(fun) = ctx.udaf(function_name) { + if let Ok(fun) = state.udaf(function_name) { // deal with situation that count(*) got no arguments let args = if fun.name() == "count" && args.is_empty() { vec![Expr::Literal(ScalarValue::Int64(Some(1)))] @@ -1517,7 +1544,7 @@ pub async fn from_substrait_agg_func( /// Convert Substrait Rex to DataFusion Expr #[async_recursion] pub async fn from_substrait_rex( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, e: &Expression, input_schema: &DFSchema, extensions: &Extensions, @@ -1528,11 +1555,11 @@ pub async fn from_substrait_rex( let substrait_list = s.options.as_ref(); Ok(Expr::InList(InList { expr: Box::new( - from_substrait_rex(ctx, substrait_expr, input_schema, extensions) + from_substrait_rex(state, substrait_expr, input_schema, extensions) .await?, ), list: from_substrait_rex_vec( - ctx, + state, substrait_list, input_schema, extensions, @@ -1555,7 +1582,7 @@ pub async fn from_substrait_rex( if if_expr.then.is_none() { expr = Some(Box::new( from_substrait_rex( - ctx, + state, if_expr.r#if.as_ref().unwrap(), input_schema, extensions, @@ -1568,7 +1595,7 @@ pub async fn from_substrait_rex( when_then_expr.push(( Box::new( from_substrait_rex( - ctx, + state, if_expr.r#if.as_ref().unwrap(), input_schema, extensions, @@ -1577,7 +1604,7 @@ pub async fn from_substrait_rex( ), Box::new( from_substrait_rex( - ctx, + state, if_expr.then.as_ref().unwrap(), input_schema, extensions, @@ -1589,7 +1616,7 @@ pub async fn from_substrait_rex( // Parse `else` let else_expr = match &if_then.r#else { Some(e) => Some(Box::new( - from_substrait_rex(ctx, e, input_schema, extensions).await?, + from_substrait_rex(state, e, input_schema, extensions).await?, )), None => None, }; @@ -1609,12 +1636,12 @@ pub async fn from_substrait_rex( let fn_name = substrait_fun_name(fn_name); let args = - from_substrait_func_args(ctx, &f.arguments, input_schema, extensions) + from_substrait_func_args(state, &f.arguments, input_schema, extensions) .await?; // try to first match the requested function into registered udfs, then built-in ops // and finally built-in expressions - if let Some(func) = ctx.state().scalar_functions().get(fn_name) { + if let Ok(func) = state.udf(fn_name) { Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( func.to_owned(), args, @@ -1644,7 +1671,7 @@ pub async fn from_substrait_rex( Ok(combined_expr) } else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) { - builder.build(ctx, f, input_schema, extensions).await + builder.build(state, f, input_schema, extensions).await } else { not_impl_err!("Unsupported function name: {fn_name:?}") } @@ -1657,7 +1684,7 @@ pub async fn from_substrait_rex( Some(output_type) => Ok(Expr::Cast(Cast::new( Box::new( from_substrait_rex( - ctx, + state, cast.as_ref().input.as_ref().unwrap().as_ref(), input_schema, extensions, @@ -1679,9 +1706,9 @@ pub async fn from_substrait_rex( let fn_name = substrait_fun_name(fn_name); // check udwf first, then udaf, then built-in window and aggregate functions - let fun = if let Ok(udwf) = ctx.udwf(fn_name) { + let fun = if let Ok(udwf) = state.udwf(fn_name) { Ok(WindowFunctionDefinition::WindowUDF(udwf)) - } else if let Ok(udaf) = ctx.udaf(fn_name) { + } else if let Ok(udaf) = state.udaf(fn_name) { Ok(WindowFunctionDefinition::AggregateUDF(udaf)) } else { not_impl_err!( @@ -1692,7 +1719,7 @@ pub async fn from_substrait_rex( }?; let order_by = - from_substrait_sorts(ctx, &window.sorts, input_schema, extensions) + from_substrait_sorts(state, &window.sorts, input_schema, extensions) .await?; let bound_units = @@ -1715,14 +1742,14 @@ pub async fn from_substrait_rex( Ok(Expr::WindowFunction(expr::WindowFunction { fun, args: from_substrait_func_args( - ctx, + state, &window.arguments, input_schema, extensions, ) .await?, partition_by: from_substrait_rex_vec( - ctx, + state, &window.partitions, input_schema, extensions, @@ -1747,13 +1774,13 @@ pub async fn from_substrait_rex( let haystack_expr = &in_predicate.haystack; if let Some(haystack_expr) = haystack_expr { let haystack_expr = - from_substrait_rel(ctx, haystack_expr, extensions) + from_substrait_rel(state, haystack_expr, extensions) .await?; let outer_refs = haystack_expr.all_out_ref_exprs(); Ok(Expr::InSubquery(InSubquery { expr: Box::new( from_substrait_rex( - ctx, + state, needle_expr, input_schema, extensions, @@ -1773,7 +1800,7 @@ pub async fn from_substrait_rex( } SubqueryType::Scalar(query) => { let plan = from_substrait_rel( - ctx, + state, &(query.input.clone()).unwrap_or_default(), extensions, ) @@ -1790,7 +1817,7 @@ pub async fn from_substrait_rex( PredicateOp::Exists => { let relation = &predicate.tuples; let plan = from_substrait_rel( - ctx, + state, &relation.clone().unwrap_or_default(), extensions, ) @@ -2772,7 +2799,7 @@ fn from_substrait_null( #[allow(deprecated)] async fn from_substrait_grouping( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, grouping: &Grouping, expressions: &[Expr], input_schema: &DFSchemaRef, @@ -2781,7 +2808,7 @@ async fn from_substrait_grouping( let mut group_exprs = vec![]; if !grouping.grouping_expressions.is_empty() { for e in &grouping.grouping_expressions { - let expr = from_substrait_rex(ctx, e, input_schema, extensions).await?; + let expr = from_substrait_rex(state, e, input_schema, extensions).await?; group_exprs.push(expr); } return Ok(group_exprs); @@ -2834,23 +2861,29 @@ impl BuiltinExprBuilder { pub async fn build( self, - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, f: &ScalarFunction, input_schema: &DFSchema, extensions: &Extensions, ) -> Result { match self.expr_name.as_str() { "like" => { - Self::build_like_expr(ctx, false, f, input_schema, extensions).await + Self::build_like_expr(state, false, f, input_schema, extensions).await } "ilike" => { - Self::build_like_expr(ctx, true, f, input_schema, extensions).await + Self::build_like_expr(state, true, f, input_schema, extensions).await } "not" | "negative" | "negate" | "is_null" | "is_not_null" | "is_true" | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" | "is_not_unknown" => { - Self::build_unary_expr(ctx, &self.expr_name, f, input_schema, extensions) - .await + Self::build_unary_expr( + state, + &self.expr_name, + f, + input_schema, + extensions, + ) + .await } _ => { not_impl_err!("Unsupported builtin expression: {}", self.expr_name) @@ -2859,7 +2892,7 @@ impl BuiltinExprBuilder { } async fn build_unary_expr( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, fn_name: &str, f: &ScalarFunction, input_schema: &DFSchema, @@ -2872,7 +2905,7 @@ impl BuiltinExprBuilder { return substrait_err!("Invalid arguments type for {fn_name} expr"); }; let arg = - from_substrait_rex(ctx, expr_substrait, input_schema, extensions).await?; + from_substrait_rex(state, expr_substrait, input_schema, extensions).await?; let arg = Box::new(arg); let expr = match fn_name { @@ -2893,7 +2926,7 @@ impl BuiltinExprBuilder { } async fn build_like_expr( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, case_insensitive: bool, f: &ScalarFunction, input_schema: &DFSchema, @@ -2908,12 +2941,13 @@ impl BuiltinExprBuilder { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; let expr = - from_substrait_rex(ctx, expr_substrait, input_schema, extensions).await?; + from_substrait_rex(state, expr_substrait, input_schema, extensions).await?; let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; let pattern = - from_substrait_rex(ctx, pattern_substrait, input_schema, extensions).await?; + from_substrait_rex(state, pattern_substrait, input_schema, extensions) + .await?; // Default case: escape character is Literal(Utf8(None)) let escape_char = if f.arguments.len() == 3 { @@ -2922,9 +2956,13 @@ impl BuiltinExprBuilder { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; - let escape_char_expr = - from_substrait_rex(ctx, escape_char_substrait, input_schema, extensions) - .await?; + let escape_char_expr = from_substrait_rex( + state, + escape_char_substrait, + input_schema, + extensions, + ) + .await?; match escape_char_expr { Expr::Literal(ScalarValue::Utf8(escape_char_string)) => { diff --git a/datafusion/substrait/src/logical_plan/mod.rs b/datafusion/substrait/src/logical_plan/mod.rs index 6f8b8e493f529..9e2fa9fa49de1 100644 --- a/datafusion/substrait/src/logical_plan/mod.rs +++ b/datafusion/substrait/src/logical_plan/mod.rs @@ -17,3 +17,4 @@ pub mod consumer; pub mod producer; +pub mod state; diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 4d864e4334ce6..29019dfd74f32 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -29,7 +29,7 @@ use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, error::{DataFusionError, Result}, logical_expr::{WindowFrame, WindowFrameBound}, - prelude::{JoinType, SessionContext}, + prelude::JoinType, scalar::ScalarValue, }; @@ -100,8 +100,13 @@ use substrait::{ version, }; +use super::state::SubstraitPlanningState; + /// Convert DataFusion LogicalPlan to Substrait Plan -pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result> { +pub fn to_substrait_plan( + plan: &LogicalPlan, + state: &dyn SubstraitPlanningState, +) -> Result> { let mut extensions = Extensions::default(); // Parse relation nodes // Generate PlanRel(s) @@ -113,7 +118,7 @@ pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result Result Result> { let mut extensions = Extensions::default(); @@ -152,7 +157,7 @@ pub fn to_substrait_extended_expr( .iter() .map(|(expr, field)| { let substrait_expr = to_substrait_rex( - ctx, + state, expr, schema, /*col_ref_offset=*/ 0, @@ -183,7 +188,7 @@ pub fn to_substrait_extended_expr( #[allow(deprecated)] pub fn to_substrait_rel( plan: &LogicalPlan, - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, extensions: &mut Extensions, ) -> Result> { match plan { @@ -284,7 +289,7 @@ pub fn to_substrait_rel( let expressions = p .expr .iter() - .map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0, extensions)) + .map(|e| to_substrait_rex(state, e, p.input.schema(), 0, extensions)) .collect::>>()?; let emit_kind = create_project_remapping( @@ -300,16 +305,16 @@ pub fn to_substrait_rel( Ok(Box::new(Rel { rel_type: Some(RelType::Project(Box::new(ProjectRel { common: Some(common), - input: Some(to_substrait_rel(p.input.as_ref(), ctx, extensions)?), + input: Some(to_substrait_rel(p.input.as_ref(), state, extensions)?), expressions, advanced_extension: None, }))), })) } LogicalPlan::Filter(filter) => { - let input = to_substrait_rel(filter.input.as_ref(), ctx, extensions)?; + let input = to_substrait_rel(filter.input.as_ref(), state, extensions)?; let filter_expr = to_substrait_rex( - ctx, + state, &filter.predicate, filter.input.schema(), 0, @@ -325,7 +330,7 @@ pub fn to_substrait_rel( })) } LogicalPlan::Limit(limit) => { - let input = to_substrait_rel(limit.input.as_ref(), ctx, extensions)?; + let input = to_substrait_rel(limit.input.as_ref(), state, extensions)?; let FetchType::Literal(fetch) = limit.get_fetch_type()? else { return not_impl_err!("Non-literal limit fetch"); }; @@ -344,11 +349,11 @@ pub fn to_substrait_rel( })) } LogicalPlan::Sort(sort) => { - let input = to_substrait_rel(sort.input.as_ref(), ctx, extensions)?; + let input = to_substrait_rel(sort.input.as_ref(), state, extensions)?; let sort_fields = sort .expr .iter() - .map(|e| substrait_sort_field(ctx, e, sort.input.schema(), extensions)) + .map(|e| substrait_sort_field(state, e, sort.input.schema(), extensions)) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Sort(Box::new(SortRel { @@ -360,9 +365,9 @@ pub fn to_substrait_rel( })) } LogicalPlan::Aggregate(agg) => { - let input = to_substrait_rel(agg.input.as_ref(), ctx, extensions)?; + let input = to_substrait_rel(agg.input.as_ref(), state, extensions)?; let (grouping_expressions, groupings) = to_substrait_groupings( - ctx, + state, &agg.group_expr, agg.input.schema(), extensions, @@ -370,7 +375,9 @@ pub fn to_substrait_rel( let measures = agg .aggr_expr .iter() - .map(|e| to_substrait_agg_measure(ctx, e, agg.input.schema(), extensions)) + .map(|e| { + to_substrait_agg_measure(state, e, agg.input.schema(), extensions) + }) .collect::>>()?; Ok(Box::new(Rel { @@ -386,7 +393,7 @@ pub fn to_substrait_rel( } LogicalPlan::Distinct(Distinct::All(plan)) => { // Use Substrait's AggregateRel with empty measures to represent `select distinct` - let input = to_substrait_rel(plan.as_ref(), ctx, extensions)?; + let input = to_substrait_rel(plan.as_ref(), state, extensions)?; // Get grouping keys from the input relation's number of output fields let grouping = (0..plan.schema().fields().len()) .map(substrait_field_ref) @@ -407,8 +414,8 @@ pub fn to_substrait_rel( })) } LogicalPlan::Join(join) => { - let left = to_substrait_rel(join.left.as_ref(), ctx, extensions)?; - let right = to_substrait_rel(join.right.as_ref(), ctx, extensions)?; + let left = to_substrait_rel(join.left.as_ref(), state, extensions)?; + let right = to_substrait_rel(join.right.as_ref(), state, extensions)?; let join_type = to_substrait_jointype(join.join_type); // we only support basic joins so return an error for anything not yet supported match join.join_constraint { @@ -421,7 +428,7 @@ pub fn to_substrait_rel( let in_join_schema = join.left.schema().join(join.right.schema())?; let join_filter = match &join.filter { Some(filter) => Some(to_substrait_rex( - ctx, + state, filter, &Arc::new(in_join_schema), 0, @@ -438,7 +445,7 @@ pub fn to_substrait_rel( Operator::Eq }; let join_on = to_substrait_join_expr( - ctx, + state, &join.on, eq_op, join.left.schema(), @@ -479,13 +486,13 @@ pub fn to_substrait_rel( LogicalPlan::SubqueryAlias(alias) => { // Do nothing if encounters SubqueryAlias // since there is no corresponding relation type in Substrait - to_substrait_rel(alias.input.as_ref(), ctx, extensions) + to_substrait_rel(alias.input.as_ref(), state, extensions) } LogicalPlan::Union(union) => { let input_rels = union .inputs .iter() - .map(|input| to_substrait_rel(input.as_ref(), ctx, extensions)) + .map(|input| to_substrait_rel(input.as_ref(), state, extensions)) .collect::>>()? .into_iter() .map(|ptr| *ptr) @@ -500,7 +507,7 @@ pub fn to_substrait_rel( })) } LogicalPlan::Window(window) => { - let input = to_substrait_rel(window.input.as_ref(), ctx, extensions)?; + let input = to_substrait_rel(window.input.as_ref(), state, extensions)?; // create a field reference for each input field let mut expressions = (0..window.input.schema().fields().len()) @@ -510,7 +517,7 @@ pub fn to_substrait_rel( // process and add each window function expression for expr in &window.window_expr { expressions.push(to_substrait_rex( - ctx, + state, expr, window.input.schema(), 0, @@ -539,7 +546,7 @@ pub fn to_substrait_rel( })) } LogicalPlan::Repartition(repartition) => { - let input = to_substrait_rel(repartition.input.as_ref(), ctx, extensions)?; + let input = to_substrait_rel(repartition.input.as_ref(), state, extensions)?; let partition_count = match repartition.partitioning_scheme { Partitioning::RoundRobinBatch(num) => num, Partitioning::Hash(_, num) => num, @@ -585,8 +592,7 @@ pub fn to_substrait_rel( })) } LogicalPlan::Extension(extension_plan) => { - let extension_bytes = ctx - .state() + let extension_bytes = state .serializer_registry() .serialize_logical_plan(extension_plan.node.as_ref())?; let detail = ProtoAny { @@ -597,7 +603,7 @@ pub fn to_substrait_rel( .node .inputs() .into_iter() - .map(|plan| to_substrait_rel(plan, ctx, extensions)) + .map(|plan| to_substrait_rel(plan, state, extensions)) .collect::>>()?; let rel_type = match inputs_rel.len() { 0 => RelType::ExtensionLeaf(ExtensionLeafRel { @@ -687,7 +693,7 @@ fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result { } fn to_substrait_join_expr( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, join_conditions: &Vec<(Expr, Expr)>, eq_op: Operator, left_schema: &DFSchemaRef, @@ -698,10 +704,10 @@ fn to_substrait_join_expr( let mut exprs: Vec = vec![]; for (left, right) in join_conditions { // Parse left - let l = to_substrait_rex(ctx, left, left_schema, 0, extensions)?; + let l = to_substrait_rex(state, left, left_schema, 0, extensions)?; // Parse right let r = to_substrait_rex( - ctx, + state, right, right_schema, left_schema.fields().len(), // offset to return the correct index @@ -770,7 +776,7 @@ pub fn operator_to_name(op: Operator) -> &'static str { #[allow(deprecated)] pub fn parse_flat_grouping_exprs( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, exprs: &[Expr], schema: &DFSchemaRef, extensions: &mut Extensions, @@ -780,7 +786,7 @@ pub fn parse_flat_grouping_exprs( let mut grouping_expressions = vec![]; for e in exprs { - let rex = to_substrait_rex(ctx, e, schema, 0, extensions)?; + let rex = to_substrait_rex(state, e, schema, 0, extensions)?; grouping_expressions.push(rex.clone()); ref_group_exprs.push(rex); expression_references.push((ref_group_exprs.len() - 1) as u32); @@ -792,7 +798,7 @@ pub fn parse_flat_grouping_exprs( } pub fn to_substrait_groupings( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, exprs: &[Expr], schema: &DFSchemaRef, extensions: &mut Extensions, @@ -808,7 +814,7 @@ pub fn to_substrait_groupings( .iter() .map(|set| { parse_flat_grouping_exprs( - ctx, + state, set, schema, extensions, @@ -826,7 +832,7 @@ pub fn to_substrait_groupings( .rev() .map(|set| { parse_flat_grouping_exprs( - ctx, + state, set, schema, extensions, @@ -837,7 +843,7 @@ pub fn to_substrait_groupings( } }, _ => Ok(vec![parse_flat_grouping_exprs( - ctx, + state, exprs, schema, extensions, @@ -845,7 +851,7 @@ pub fn to_substrait_groupings( )?]), }, _ => Ok(vec![parse_flat_grouping_exprs( - ctx, + state, exprs, schema, extensions, @@ -857,7 +863,7 @@ pub fn to_substrait_groupings( #[allow(deprecated)] pub fn to_substrait_agg_measure( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, expr: &Expr, schema: &DFSchemaRef, extensions: &mut Extensions, @@ -865,13 +871,13 @@ pub fn to_substrait_agg_measure( match expr { Expr::AggregateFunction(expr::AggregateFunction { func, args, distinct, filter, order_by, null_treatment: _, }) => { let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extensions)).collect::>>()? + order_by.iter().map(|expr| to_substrait_sort_field(state, expr, schema, extensions)).collect::>>()? } else { vec![] }; let mut arguments: Vec = vec![]; for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extensions)?)) }); + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(state, arg, schema, 0, extensions)?)) }); } let function_anchor = extensions.register_function(func.name().to_string()); Ok(Measure { @@ -889,14 +895,14 @@ pub fn to_substrait_agg_measure( options: vec![], }), filter: match filter { - Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extensions)?), + Some(f) => Some(to_substrait_rex(state, f, schema, 0, extensions)?), None => None } }) } Expr::Alias(Alias{expr,..})=> { - to_substrait_agg_measure(ctx, expr, schema, extensions) + to_substrait_agg_measure(state, expr, schema, extensions) } _ => internal_err!( "Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}", @@ -908,7 +914,7 @@ pub fn to_substrait_agg_measure( /// Converts sort expression to corresponding substrait `SortField` fn to_substrait_sort_field( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, sort: &Sort, schema: &DFSchemaRef, extensions: &mut Extensions, @@ -920,7 +926,7 @@ fn to_substrait_sort_field( (false, false) => SortDirection::DescNullsLast, }; Ok(SortField { - expr: Some(to_substrait_rex(ctx, &sort.expr, schema, 0, extensions)?), + expr: Some(to_substrait_rex(state, &sort.expr, schema, 0, extensions)?), sort_kind: Some(SortKind::Direction(sort_kind.into())), }) } @@ -977,7 +983,7 @@ pub fn make_binary_op_scalar_func( /// * `extensions` - Substrait extension info. Contains registered function information #[allow(deprecated)] pub fn to_substrait_rex( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, expr: &Expr, schema: &DFSchemaRef, col_ref_offset: usize, @@ -991,10 +997,10 @@ pub fn to_substrait_rex( }) => { let substrait_list = list .iter() - .map(|x| to_substrait_rex(ctx, x, schema, col_ref_offset, extensions)) + .map(|x| to_substrait_rex(state, x, schema, col_ref_offset, extensions)) .collect::>>()?; let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; let substrait_or_list = Expression { rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { @@ -1026,7 +1032,7 @@ pub fn to_substrait_rex( for arg in &fun.args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex( - ctx, + state, arg, schema, col_ref_offset, @@ -1055,11 +1061,11 @@ pub fn to_substrait_rex( if *negated { // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; let substrait_low = - to_substrait_rex(ctx, low, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, low, schema, col_ref_offset, extensions)?; let substrait_high = - to_substrait_rex(ctx, high, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, high, schema, col_ref_offset, extensions)?; let l_expr = make_binary_op_scalar_func( &substrait_expr, @@ -1083,11 +1089,11 @@ pub fn to_substrait_rex( } else { // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; let substrait_low = - to_substrait_rex(ctx, low, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, low, schema, col_ref_offset, extensions)?; let substrait_high = - to_substrait_rex(ctx, high, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, high, schema, col_ref_offset, extensions)?; let l_expr = make_binary_op_scalar_func( &substrait_low, @@ -1115,8 +1121,8 @@ pub fn to_substrait_rex( substrait_field_ref(index + col_ref_offset) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = to_substrait_rex(ctx, left, schema, col_ref_offset, extensions)?; - let r = to_substrait_rex(ctx, right, schema, col_ref_offset, extensions)?; + let l = to_substrait_rex(state, left, schema, col_ref_offset, extensions)?; + let r = to_substrait_rex(state, right, schema, col_ref_offset, extensions)?; Ok(make_binary_op_scalar_func(&l, &r, *op, extensions)) } @@ -1131,7 +1137,7 @@ pub fn to_substrait_rex( // Base expression exists ifs.push(IfClause { r#if: Some(to_substrait_rex( - ctx, + state, e, schema, col_ref_offset, @@ -1144,14 +1150,14 @@ pub fn to_substrait_rex( for (r#if, then) in when_then_expr { ifs.push(IfClause { r#if: Some(to_substrait_rex( - ctx, + state, r#if, schema, col_ref_offset, extensions, )?), then: Some(to_substrait_rex( - ctx, + state, then, schema, col_ref_offset, @@ -1163,7 +1169,7 @@ pub fn to_substrait_rex( // Parse outer `else` let r#else: Option> = match else_expr { Some(e) => Some(Box::new(to_substrait_rex( - ctx, + state, e, schema, col_ref_offset, @@ -1182,7 +1188,7 @@ pub fn to_substrait_rex( substrait::proto::expression::Cast { r#type: Some(to_substrait_type(data_type, true)?), input: Some(Box::new(to_substrait_rex( - ctx, + state, expr, schema, col_ref_offset, @@ -1195,7 +1201,7 @@ pub fn to_substrait_rex( } Expr::Literal(value) => to_substrait_literal_expr(value, extensions), Expr::Alias(Alias { expr, .. }) => { - to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions) + to_substrait_rex(state, expr, schema, col_ref_offset, extensions) } Expr::WindowFunction(WindowFunction { fun, @@ -1212,7 +1218,7 @@ pub fn to_substrait_rex( for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex( - ctx, + state, arg, schema, col_ref_offset, @@ -1223,12 +1229,12 @@ pub fn to_substrait_rex( // partition by expressions let partition_by = partition_by .iter() - .map(|e| to_substrait_rex(ctx, e, schema, col_ref_offset, extensions)) + .map(|e| to_substrait_rex(state, e, schema, col_ref_offset, extensions)) .collect::>>()?; // order by expressions let order_by = order_by .iter() - .map(|e| substrait_sort_field(ctx, e, schema, extensions)) + .map(|e| substrait_sort_field(state, e, schema, extensions)) .collect::>>()?; // window frame let bounds = to_substrait_bounds(window_frame)?; @@ -1249,7 +1255,7 @@ pub fn to_substrait_rex( escape_char, case_insensitive, }) => make_substrait_like_expr( - ctx, + state, *case_insensitive, *negated, expr, @@ -1265,10 +1271,10 @@ pub fn to_substrait_rex( negated, }) => { let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; let subquery_plan = - to_substrait_rel(subquery.subquery.as_ref(), ctx, extensions)?; + to_substrait_rel(subquery.subquery.as_ref(), state, extensions)?; let substrait_subquery = Expression { rex_type: Some(RexType::Subquery(Box::new(Subquery { @@ -1301,7 +1307,7 @@ pub fn to_substrait_rex( } } Expr::Not(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "not", arg, schema, @@ -1309,7 +1315,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsNull(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_null", arg, schema, @@ -1317,7 +1323,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_not_null", arg, schema, @@ -1325,7 +1331,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsTrue(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_true", arg, schema, @@ -1333,7 +1339,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsFalse(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_false", arg, schema, @@ -1341,7 +1347,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_unknown", arg, schema, @@ -1349,7 +1355,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_not_true", arg, schema, @@ -1357,7 +1363,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_not_false", arg, schema, @@ -1365,7 +1371,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_not_unknown", arg, schema, @@ -1373,7 +1379,7 @@ pub fn to_substrait_rex( extensions, ), Expr::Negative(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "negate", arg, schema, @@ -1674,7 +1680,7 @@ fn make_substrait_window_function( #[allow(deprecated)] #[allow(clippy::too_many_arguments)] fn make_substrait_like_expr( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, ignore_case: bool, negated: bool, expr: &Expr, @@ -1689,8 +1695,8 @@ fn make_substrait_like_expr( } else { extensions.register_function("like".to_string()) }; - let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; - let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extensions)?; + let expr = to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; + let pattern = to_substrait_rex(state, pattern, schema, col_ref_offset, extensions)?; let escape_char = to_substrait_literal_expr( &ScalarValue::Utf8(escape_char.map(|c| c.to_string())), extensions, @@ -2088,7 +2094,7 @@ fn to_substrait_literal_expr( /// Util to generate substrait [RexType::ScalarFunction] with one argument fn to_substrait_unary_scalar_fn( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, fn_name: &str, arg: &Expr, schema: &DFSchemaRef, @@ -2096,7 +2102,8 @@ fn to_substrait_unary_scalar_fn( extensions: &mut Extensions, ) -> Result { let function_anchor = extensions.register_function(fn_name.to_string()); - let substrait_expr = to_substrait_rex(ctx, arg, schema, col_ref_offset, extensions)?; + let substrait_expr = + to_substrait_rex(state, arg, schema, col_ref_offset, extensions)?; Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -2137,7 +2144,7 @@ fn try_to_substrait_field_reference( } fn substrait_sort_field( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, sort: &Sort, schema: &DFSchemaRef, extensions: &mut Extensions, @@ -2147,7 +2154,7 @@ fn substrait_sort_field( asc, nulls_first, } = sort; - let e = to_substrait_rex(ctx, expr, schema, 0, extensions)?; + let e = to_substrait_rex(state, expr, schema, 0, extensions)?; let d = match (asc, nulls_first) { (true, true) => SortDirection::AscNullsFirst, (true, false) => SortDirection::AscNullsLast, @@ -2190,6 +2197,7 @@ mod test { use datafusion::arrow::datatypes::{Field, Fields, Schema}; use datafusion::common::scalar::ScalarStructBuilder; use datafusion::common::DFSchema; + use datafusion::execution::SessionStateBuilder; #[test] fn round_trip_literals() -> Result<()> { @@ -2433,15 +2441,15 @@ mod test { #[tokio::test] async fn extended_expressions() -> Result<()> { - let ctx = SessionContext::new(); + let state = SessionStateBuilder::default().build(); // One expression, empty input schema let expr = Expr::Literal(ScalarValue::Int32(Some(42))); let field = Field::new("out", DataType::Int32, false); let empty_schema = DFSchemaRef::new(DFSchema::empty()); let substrait = - to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &ctx)?; - let roundtrip_expr = from_substrait_extended_expr(&ctx, &substrait).await?; + to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state)?; + let roundtrip_expr = from_substrait_extended_expr(&state, &substrait).await?; assert_eq!(roundtrip_expr.input_schema, empty_schema); assert_eq!(roundtrip_expr.exprs.len(), 1); @@ -2463,9 +2471,9 @@ mod test { let substrait = to_substrait_extended_expr( &[(&expr1, &out1), (&expr2, &out2)], &input_schema, - &ctx, + &state, )?; - let roundtrip_expr = from_substrait_extended_expr(&ctx, &substrait).await?; + let roundtrip_expr = from_substrait_extended_expr(&state, &substrait).await?; assert_eq!(roundtrip_expr.input_schema, input_schema); assert_eq!(roundtrip_expr.exprs.len(), 2); @@ -2485,14 +2493,14 @@ mod test { #[tokio::test] async fn invalid_extended_expression() { - let ctx = SessionContext::new(); + let state = SessionStateBuilder::default().build(); // Not ok if input schema is missing field referenced by expr let expr = Expr::Column("missing".into()); let field = Field::new("out", DataType::Int32, false); let empty_schema = DFSchemaRef::new(DFSchema::empty()); - let err = to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &ctx); + let err = to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state); assert!(matches!(err, Err(DataFusionError::SchemaError(_, _)))); } diff --git a/datafusion/substrait/src/logical_plan/state.rs b/datafusion/substrait/src/logical_plan/state.rs new file mode 100644 index 0000000000000..0bd749c1105db --- /dev/null +++ b/datafusion/substrait/src/logical_plan/state.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion::{ + catalog::TableProvider, + error::{DataFusionError, Result}, + execution::{registry::SerializerRegistry, FunctionRegistry, SessionState}, + sql::TableReference, +}; + +/// This trait provides the context needed to transform a substrait plan into a +/// [`datafusion::logical_expr::LogicalPlan`] (via [`super::consumer::from_substrait_plan`]) +/// and back again into a substrait plan (via [`super::producer::to_substrait_plan`]). +/// +/// The context is declared as a trait to decouple the substrait plan encoder / +/// decoder from the [`SessionState`], potentially allowing users to define +/// their own slimmer context just for serializing and deserializing substrait. +/// +/// [`SessionState`] implements this trait. +#[async_trait] +pub trait SubstraitPlanningState: Sync + Send + FunctionRegistry { + /// Return [SerializerRegistry] for extensions + fn serializer_registry(&self) -> &Arc; + + async fn table( + &self, + reference: &TableReference, + ) -> Result>>; +} + +#[async_trait] +impl SubstraitPlanningState for SessionState { + fn serializer_registry(&self) -> &Arc { + self.serializer_registry() + } + + async fn table( + &self, + reference: &TableReference, + ) -> Result>, DataFusionError> { + let table = reference.table().to_string(); + let schema = self.schema_for_ref(reference.clone())?; + let table_provider = schema.table(&table).await?; + Ok(table_provider) + } +} diff --git a/datafusion/substrait/src/serializer.rs b/datafusion/substrait/src/serializer.rs index 6b81e33dfc374..4278671777fda 100644 --- a/datafusion/substrait/src/serializer.rs +++ b/datafusion/substrait/src/serializer.rs @@ -38,7 +38,7 @@ pub async fn serialize(sql: &str, ctx: &SessionContext, path: &str) -> Result<() pub async fn serialize_bytes(sql: &str, ctx: &SessionContext) -> Result> { let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; - let proto = producer::to_substrait_plan(&plan, ctx)?; + let proto = producer::to_substrait_plan(&plan, &ctx.state())?; let mut protobuf_out = Vec::::new(); proto.encode(&mut protobuf_out).map_err(|e| { diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index bc38ef82977f3..219f656bb471e 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -41,7 +41,7 @@ mod tests { .expect("failed to parse json"); let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?; - let plan = from_substrait_plan(&ctx, &proto).await?; + let plan = from_substrait_plan(&ctx.state(), &proto).await?; Ok(format!("{}", plan)) } diff --git a/datafusion/substrait/tests/cases/emit_kind_tests.rs b/datafusion/substrait/tests/cases/emit_kind_tests.rs index ac66177ed796f..08537d0d110f8 100644 --- a/datafusion/substrait/tests/cases/emit_kind_tests.rs +++ b/datafusion/substrait/tests/cases/emit_kind_tests.rs @@ -33,7 +33,7 @@ mod tests { "tests/testdata/test_plans/emit_kind/direct_on_project.substrait.json", ); let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; let plan_str = format!("{}", plan); @@ -51,7 +51,7 @@ mod tests { "tests/testdata/test_plans/emit_kind/emit_on_filter.substrait.json", ); let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; let plan_str = format!("{}", plan); @@ -91,8 +91,8 @@ mod tests { \n TableScan: data" ); - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; // note how the Projections are not flattened assert_eq!( format!("{}", plan2), @@ -115,8 +115,8 @@ mod tests { \n TableScan: data" ); - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; let plan1str = format!("{plan}"); let plan2str = format!("{plan2}"); diff --git a/datafusion/substrait/tests/cases/function_test.rs b/datafusion/substrait/tests/cases/function_test.rs index b136b0af19c29..043808456176a 100644 --- a/datafusion/substrait/tests/cases/function_test.rs +++ b/datafusion/substrait/tests/cases/function_test.rs @@ -29,7 +29,7 @@ mod tests { async fn contains_function_test() -> Result<()> { let proto_plan = read_json("tests/testdata/contains_plan.substrait.json"); let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; let plan_str = format!("{}", plan); diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index f4e34af35d78e..65f404bbda555 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -38,7 +38,7 @@ mod tests { let proto_plan = read_json("tests/testdata/test_plans/select_not_bool.substrait.json"); let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!( format!("{}", plan), @@ -63,7 +63,7 @@ mod tests { let proto_plan = read_json("tests/testdata/test_plans/select_window.substrait.json"); let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!( format!("{}", plan), @@ -82,7 +82,7 @@ mod tests { let proto_plan = read_json("tests/testdata/test_plans/non_nullable_lists.substrait.json"); let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!(format!("{}", &plan), "Values: (List([1, 2]))"); diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index d4e2d48885ae6..d03ab5182028a 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -979,8 +979,8 @@ async fn extension_logical_plan() -> Result<()> { }), }); - let proto = to_substrait_plan(&ext_plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&ext_plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; let plan1str = format!("{ext_plan}"); let plan2str = format!("{plan2}"); @@ -1081,8 +1081,8 @@ async fn roundtrip_repartition_roundrobin() -> Result<()> { partitioning_scheme: Partitioning::RoundRobinBatch(8), }); - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; assert_eq!(format!("{plan}"), format!("{plan2}")); @@ -1098,8 +1098,8 @@ async fn roundtrip_repartition_hash() -> Result<()> { partitioning_scheme: Partitioning::Hash(vec![col("data.a")], 8), }); - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; assert_eq!(format!("{plan}"), format!("{plan2}")); @@ -1199,8 +1199,8 @@ async fn assert_expected_plan_unoptimized( let ctx = create_context().await?; let df = ctx.sql(sql).await?; let plan = df.into_unoptimized_plan(); - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; println!("{plan}"); println!("{plan2}"); @@ -1225,8 +1225,8 @@ async fn assert_expected_plan( let ctx = create_context().await?; let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; println!("{plan}"); @@ -1250,7 +1250,7 @@ async fn assert_expected_plan_substrait( ) -> Result<()> { let ctx = create_context().await?; - let plan = from_substrait_plan(&ctx, &substrait_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &substrait_plan).await?; let plan = ctx.state().optimize(&plan)?; @@ -1265,7 +1265,7 @@ async fn assert_substrait_sql(substrait_plan: Plan, sql: &str) -> Result<()> { let expected = ctx.sql(sql).await?.into_optimized_plan()?; - let plan = from_substrait_plan(&ctx, &substrait_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &substrait_plan).await?; let plan = ctx.state().optimize(&plan)?; @@ -1280,8 +1280,8 @@ async fn roundtrip_fill_na(sql: &str) -> Result<()> { let ctx = create_context().await?; let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; // Format plan string and replace all None's with 0 @@ -1301,12 +1301,12 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> { let ctx = create_context().await?; let df_a = ctx.sql(sql_with_alias).await?; - let proto_a = to_substrait_plan(&df_a.into_optimized_plan()?, &ctx)?; - let plan_with_alias = from_substrait_plan(&ctx, &proto_a).await?; + let proto_a = to_substrait_plan(&df_a.into_optimized_plan()?, &ctx.state())?; + let plan_with_alias = from_substrait_plan(&ctx.state(), &proto_a).await?; let df = ctx.sql(sql_no_alias).await?; - let proto = to_substrait_plan(&df.into_optimized_plan()?, &ctx)?; - let plan = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&df.into_optimized_plan()?, &ctx.state())?; + let plan = from_substrait_plan(&ctx.state(), &proto).await?; println!("{plan_with_alias}"); println!("{plan}"); @@ -1323,8 +1323,8 @@ async fn roundtrip_logical_plan_with_ctx( plan: LogicalPlan, ctx: SessionContext, ) -> Result> { - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; println!("{plan}"); diff --git a/datafusion/substrait/tests/cases/serialize.rs b/datafusion/substrait/tests/cases/serialize.rs index 54d55d1b6f10e..e28c63312788f 100644 --- a/datafusion/substrait/tests/cases/serialize.rs +++ b/datafusion/substrait/tests/cases/serialize.rs @@ -45,7 +45,7 @@ mod tests { // Read substrait plan from file let proto = serializer::deserialize(path).await?; // Check plan equality - let plan = from_substrait_plan(&ctx, &proto).await?; + let plan = from_substrait_plan(&ctx.state(), &proto).await?; let plan_str_ref = format!("{plan_ref}"); let plan_str = format!("{plan}"); assert_eq!(plan_str_ref, plan_str); @@ -60,7 +60,7 @@ mod tests { let ctx = create_context().await?; let table = provider_as_source(ctx.table_provider("data").await?); let table_scan = LogicalPlanBuilder::scan("data", table, None)?.build()?; - let convert_result = to_substrait_plan(&table_scan, &ctx); + let convert_result = to_substrait_plan(&table_scan, &ctx.state()); assert!(convert_result.is_ok()); Ok(()) @@ -78,7 +78,9 @@ mod tests { \n TableScan: data projection=[a, b]", ); - let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone(); + let plan = to_substrait_plan(&datafusion_plan, &ctx.state())? + .as_ref() + .clone(); let relation = plan.relations.first().unwrap().rel_type.as_ref(); let root_rel = match relation { @@ -121,7 +123,9 @@ mod tests { \n TableScan: data projection=[a, b, c]", ); - let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone(); + let plan = to_substrait_plan(&datafusion_plan, &ctx.state())? + .as_ref() + .clone(); let relation = plan.relations.first().unwrap().rel_type.as_ref(); let root_rel = match relation { diff --git a/datafusion/substrait/tests/cases/substrait_validations.rs b/datafusion/substrait/tests/cases/substrait_validations.rs index 5ae586afe56f8..c77bf1489f4e7 100644 --- a/datafusion/substrait/tests/cases/substrait_validations.rs +++ b/datafusion/substrait/tests/cases/substrait_validations.rs @@ -65,7 +65,7 @@ mod tests { vec![("a", DataType::Int32, false), ("b", DataType::Int32, true)]; let ctx = generate_context_with_table("DATA", df_schema)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!( format!("{}", plan), @@ -86,7 +86,7 @@ mod tests { ("c", DataType::Int32, false), ]; let ctx = generate_context_with_table("DATA", df_schema)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!( format!("{}", plan), @@ -109,7 +109,7 @@ mod tests { ("b", DataType::Int32, false), ]; let ctx = generate_context_with_table("DATA", df_schema)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!( format!("{}", plan), @@ -128,7 +128,7 @@ mod tests { vec![("a", DataType::Int32, false), ("c", DataType::Int32, true)]; let ctx = generate_context_with_table("DATA", df_schema)?; - let res = from_substrait_plan(&ctx, &proto_plan).await; + let res = from_substrait_plan(&ctx.state(), &proto_plan).await; assert!(res.is_err()); Ok(()) } @@ -140,7 +140,7 @@ mod tests { let ctx = generate_context_with_table("DATA", vec![("a", DataType::Date32, true)])?; - let res = from_substrait_plan(&ctx, &proto_plan).await; + let res = from_substrait_plan(&ctx.state(), &proto_plan).await; assert!(res.is_err()); Ok(()) } From c3e1173351328ac1bd966659d429e8986619d40e Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Wed, 20 Nov 2024 17:23:09 -0500 Subject: [PATCH 40/45] Fixed issue with md5 not support LargeUtf8 correctly (#13502) * Fix md5 return_type to only return Utf8 as per current code impl. * Added tests from #13443 to verify fix --- datafusion/functions/src/crypto/md5.rs | 4 +- .../test_files/string/string_query.slt.part | 109 ++++++++++++++++++ 2 files changed, 111 insertions(+), 2 deletions(-) diff --git a/datafusion/functions/src/crypto/md5.rs b/datafusion/functions/src/crypto/md5.rs index 0f18fd47b4cf0..0e8ff1cd31928 100644 --- a/datafusion/functions/src/crypto/md5.rs +++ b/datafusion/functions/src/crypto/md5.rs @@ -64,11 +64,11 @@ impl ScalarUDFImpl for Md5Func { fn return_type(&self, arg_types: &[DataType]) -> Result { use DataType::*; Ok(match &arg_types[0] { - LargeUtf8 | LargeBinary => LargeUtf8, + LargeUtf8 | LargeBinary => Utf8, Utf8View | Utf8 | Binary => Utf8, Null => Null, Dictionary(_, t) => match **t { - LargeUtf8 | LargeBinary => LargeUtf8, + LargeUtf8 | LargeBinary => Utf8, Utf8 | Binary => Utf8, Null => Null, _ => { diff --git a/datafusion/sqllogictest/test_files/string/string_query.slt.part b/datafusion/sqllogictest/test_files/string/string_query.slt.part index f781b9dc33caf..c42a9384c5d0c 100644 --- a/datafusion/sqllogictest/test_files/string/string_query.slt.part +++ b/datafusion/sqllogictest/test_files/string/string_query.slt.part @@ -1373,3 +1373,112 @@ p percent NULL pan Tadeusz ma iść w kąt pan Tadeusz ma iść w kąt NULL _ _ NULL (empty) (empty) NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test md5 +# -------------------------------------- + +query T +select md5(ascii_1) from test_basic_operator; +---- +8aae3a73a9a43ee6b04dfd986fe9d136 +76515af83bcb9d6336fe42dba18e716d +84fc7720d5e7bf07115d91762843b8ad +e0c4c75d58916b22a41b6ea9bc46231f +354f047ba64552895b016bbdd60ab174 +d41d8cd98f00b204e9800998ecf8427e +0bcef9c45bd8a48eda1b26eb0c61c869 +b14a7b8059d9c055954c92674ce60032 +NULL +NULL + +# -------------------------------------- +# Test sha244 +# -------------------------------------- + +query ? +select sha224(ascii_1) from test_basic_operator; +---- +abd8be3961e5dbe324bc67f9a0211d5f7d81e556baadaff6218e4bfa +87a20c95932524a54a0263a621fe791a5d5fbc0e40242b59732d6bf5 +8dd0c8021fe87bbc1c0701bd3130e27a639dcd93083c3f1989ffdf26 +8f6caa44143a080541f083bb762107ce12224b271bfa8b36ece002ab +951336d101e034714ba1ca0535688f0300613e235814ed938cd25115 +d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f +fda2a4d4c5fb67cfd7fc817f59b543ae42f650aa4abd79934ca5ac55 +d365e3c7512c311d0df0528a850e6c827cbe508d13235fa91b545389 +NULL +NULL + +# -------------------------------------- +# Test sha256 +# -------------------------------------- + +query ? +select sha256(ascii_1) from test_basic_operator; +---- +c10873196eb1124ed74461c20a67094e395f2310f6305607b9694ee6b1ee8b43 +ec792d2e89af0d5b05c88ee1e5fe041ce2db94f84c3aabac4f7cfe20f00cd032 +053e9c5f1a29bea66ff896d7a8f217bf380b8e3973e7f13c1acbe14ef7fc947e +d8071166bbe6131a0acaf86019eeeca31c87ee4fda23b80eda0d094dbffee521 +fd86717aca41c558c78c19ab2b50691179a57ba5200bc7e3317be70efd4043ad +e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 +bbf3f11cb5b43e700273a78d12de55e4a7eab741ed2abf13787a4d2dc832b8ec +d2e2adf7177b7a8afddbc12d1634cf23ea1a71020f6a1308070a16400fb68fde +NULL +NULL + +# -------------------------------------- +# Test sha384 +# -------------------------------------- + +query ? +select sha384(ascii_1) from test_basic_operator; +---- +33a2a749758403660d131256e08647f52e4efba74840e7ad55c77012ade611ec0dc815ab3fa777e98710d43f3345222b +7b525a4147696421c6119df0e983ee3d9ebcfa13b3e1dce2fb308f91863e236fde55b56b89936908999332f5a453845c +359ee4b366b1965e9ceb0bd529edcdc08c33b0348aa4cc2cf4114c7f18069d53f6a798482626393c46ed340995c34b4e +fe417fcff1b9b8cdbc4fba45fedcd882ccbeef438497647052809fd73f43bcf1a6214f543a91e7183d56c6ae8e7cb30e +7791b34dcc841235a8a074052bc12aa7090c0d72f09ec41b1521a67fa09b026a9c02d159b42428d7b528aa5ff7598fd4 +38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b +bba987e661a4158451c5e9870fe91f483064574a0d7485caef40f48d7846579859c7dddebd418cbc99ccaa1ebd3619ea +586b0fd9f8ec935c69a7dceb5560742f368962833023906d30fe1cf49c96ea6d22cea8c2b63cd18e7af08fbf9e47c3f9 +NULL +NULL + + +# -------------------------------------- +# Test sha512 +# -------------------------------------- + +query ? +select sha512(ascii_1) from test_basic_operator; +---- +93262eb44d649a02a83b78889fd813ce819759daabcee2ac433f1ea7feef44f521ac0eba5b5359d47c7a7146afbe064b55134a63ac713c0fcc4c48e11eed7109 +f02c73afb1e433d6cc7e9137bb4ed40791e8c6e7877ae26e7a1edc4ce98a945a61bdf883d985adbc03d74d67ac18d4981529be5f4f53a35ff7fcd3e9814592d7 +2f25e277902f07a4c5cdb54485487b50bae3acdd615cd5551f71f4e3d97077fbccfbf0c85f88d6766d132069a343b732c6e81080a2c3ed59caff0c6947f4c57a +cafc51edc3a949179a74a805be8d0c7991bfc849b01f773f4bcd5e7dbe51b6d71d65921d8025d375d501af6a1c1026ab76cd7f4811b91bb4544f7dcbb710fa1f +2f845edf0e9c9728fae627d4678dc8c35c9a7f22809d355aa5ddf96d9ca3539973ac7ff96bfc6720ce6a973f93b716e265ad719ee38a85e44d9316ac1b6c89a4 +cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e +91972aa34055bca20ddb643b9f817a547e5d4ad49b7ff16a7f828a8d72c4cb4a5679cff4da00f9fb6b2833de7eb3480b3b4a7c7c7b85a39028de55acaf2d8812 +bbbe7f2559c7953d281fba7f25258063dbc8a55c5b9fdfcd334ecd64a8d7d8980c6f6ee0457bf496bcff747991f741446f1814222678dfa7457f1ad3a6f848b3 +NULL +NULL + +# -------------------------------------- +# Test DIGEST +# -------------------------------------- + +query ? +select DIGEST(ascii_1, 'sha256') from test_basic_operator; +---- +c10873196eb1124ed74461c20a67094e395f2310f6305607b9694ee6b1ee8b43 +ec792d2e89af0d5b05c88ee1e5fe041ce2db94f84c3aabac4f7cfe20f00cd032 +053e9c5f1a29bea66ff896d7a8f217bf380b8e3973e7f13c1acbe14ef7fc947e +d8071166bbe6131a0acaf86019eeeca31c87ee4fda23b80eda0d094dbffee521 +fd86717aca41c558c78c19ab2b50691179a57ba5200bc7e3317be70efd4043ad +e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 +bbf3f11cb5b43e700273a78d12de55e4a7eab741ed2abf13787a4d2dc832b8ec +d2e2adf7177b7a8afddbc12d1634cf23ea1a71020f6a1308070a16400fb68fde +NULL +NULL \ No newline at end of file From 240402d28a2731e6e272f059dacfc1129d70c175 Mon Sep 17 00:00:00 2001 From: Dmitrii Blaginin Date: Thu, 21 Nov 2024 00:04:46 +0000 Subject: [PATCH 41/45] Coerce Array inner types (#13452) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Coerce Array inner types * `coerce_list_children` macro → fn * Rearrange & simplify patterns * Add casting for `FixedSizeList` * Add sql logic test * Fix test * Expand error * Switch to `Arc::clone` * OR nullable-s * Add successful test * Switch to `type_union_resolution` * Add a note on `coerce_list_children` * Handle different list types inside `type_union_resolution_coercion` * Add a `make_array` test with different array types * Add the test value * Update datafusion/sqllogictest/test_files/array.slt Co-authored-by: Jay Zhan * Reorder match --------- Co-authored-by: Jay Zhan --- .../expr-common/src/type_coercion/binary.rs | 73 +++++++++++++++---- datafusion/functions-nested/src/make_array.rs | 3 +- datafusion/sqllogictest/test_files/array.slt | 12 +++ datafusion/sqllogictest/test_files/union.slt | 16 ++++ 4 files changed, 86 insertions(+), 18 deletions(-) diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index c32b4951db448..bff74252df7b7 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -480,11 +480,6 @@ fn type_union_resolution_coercion( let new_value_type = type_union_resolution_coercion(value_type, other_type); new_value_type.map(|t| DataType::Dictionary(index_type.clone(), Box::new(t))) } - (DataType::List(lhs), DataType::List(rhs)) => { - let new_item_type = - type_union_resolution_coercion(lhs.data_type(), rhs.data_type()); - new_item_type.map(|t| DataType::List(Arc::new(Field::new("item", t, true)))) - } (DataType::Struct(lhs), DataType::Struct(rhs)) => { if lhs.len() != rhs.len() { return None; @@ -529,6 +524,7 @@ fn type_union_resolution_coercion( // Numeric coercion is the same as comparison coercion, both find the narrowest type // that can accommodate both types binary_numeric_coercion(lhs_type, rhs_type) + .or_else(|| list_coercion(lhs_type, rhs_type)) .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type)) .or_else(|| string_coercion(lhs_type, rhs_type)) .or_else(|| numeric_string_coercion(lhs_type, rhs_type)) @@ -1138,27 +1134,46 @@ fn numeric_string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { + let data_types = vec![lhs_field.data_type().clone(), rhs_field.data_type().clone()]; + Some(Arc::new( + (**lhs_field) + .clone() + .with_data_type(type_union_resolution(&data_types)?) + .with_nullable(lhs_field.is_nullable() || rhs_field.is_nullable()), + )) +} + /// Coercion rules for list types. fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { - (List(_), List(_)) => Some(lhs_type.clone()), - (LargeList(_), List(_)) => Some(lhs_type.clone()), - (List(_), LargeList(_)) => Some(rhs_type.clone()), - (LargeList(_), LargeList(_)) => Some(lhs_type.clone()), - (List(_), FixedSizeList(_, _)) => Some(lhs_type.clone()), - (FixedSizeList(_, _), List(_)) => Some(rhs_type.clone()), // Coerce to the left side FixedSizeList type if the list lengths are the same, // otherwise coerce to list with the left type for dynamic length - (FixedSizeList(lf, ls), FixedSizeList(_, rs)) => { + (FixedSizeList(lhs_field, ls), FixedSizeList(rhs_field, rs)) => { if ls == rs { - Some(lhs_type.clone()) + Some(FixedSizeList( + coerce_list_children(lhs_field, rhs_field)?, + *rs, + )) } else { - Some(List(Arc::clone(lf))) + Some(List(coerce_list_children(lhs_field, rhs_field)?)) } } - (LargeList(_), FixedSizeList(_, _)) => Some(lhs_type.clone()), - (FixedSizeList(_, _), LargeList(_)) => Some(rhs_type.clone()), + // LargeList on any side + ( + LargeList(lhs_field), + List(rhs_field) | LargeList(rhs_field) | FixedSizeList(rhs_field, _), + ) + | (List(lhs_field) | FixedSizeList(lhs_field, _), LargeList(rhs_field)) => { + Some(LargeList(coerce_list_children(lhs_field, rhs_field)?)) + } + // Lists on both sides + (List(lhs_field), List(rhs_field) | FixedSizeList(rhs_field, _)) + | (FixedSizeList(lhs_field, _), List(rhs_field)) => { + Some(List(coerce_list_children(lhs_field, rhs_field)?)) + } _ => None, } } @@ -2105,10 +2120,36 @@ mod tests { DataType::List(Arc::clone(&inner_field)) ); + // Negative test: inner_timestamp_field and inner_field are not compatible because their inner types are not compatible + let inner_timestamp_field = Arc::new(Field::new( + "item", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + )); + let result_type = get_input_types( + &DataType::List(Arc::clone(&inner_field)), + &Operator::Eq, + &DataType::List(Arc::clone(&inner_timestamp_field)), + ); + assert!(result_type.is_err()); + // TODO add other data type Ok(()) } + #[test] + fn test_list_coercion() { + let lhs_type = DataType::List(Arc::new(Field::new("lhs", DataType::Int8, false))); + + let rhs_type = DataType::List(Arc::new(Field::new("rhs", DataType::Int64, true))); + + let coerced_type = list_coercion(&lhs_type, &rhs_type).unwrap(); + assert_eq!( + coerced_type, + DataType::List(Arc::new(Field::new("lhs", DataType::Int64, true))) + ); // nullable because the RHS is nullable + } + #[test] fn test_type_coercion_logical_op() -> Result<()> { test_coercion_binary_rule!( diff --git a/datafusion/functions-nested/src/make_array.rs b/datafusion/functions-nested/src/make_array.rs index de67b0ae38749..c84b6f010968c 100644 --- a/datafusion/functions-nested/src/make_array.rs +++ b/datafusion/functions-nested/src/make_array.rs @@ -26,7 +26,7 @@ use arrow_array::{ new_null_array, Array, ArrayRef, GenericListArray, NullArray, OffsetSizeTrait, }; use arrow_buffer::OffsetBuffer; -use arrow_schema::DataType::{LargeList, List, Null}; +use arrow_schema::DataType::{List, Null}; use arrow_schema::{DataType, Field}; use datafusion_common::{plan_err, utils::array_into_list_array_nullable, Result}; use datafusion_expr::binary::{ @@ -198,7 +198,6 @@ pub(crate) fn make_array_inner(arrays: &[ArrayRef]) -> Result { let array = new_null_array(&DataType::Int64, length); Ok(Arc::new(array_into_list_array_nullable(array))) } - LargeList(..) => array_array::(arrays, data_type), _ => array_array::(arrays, data_type), } } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 1e60699a1f653..e6676d683f914 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1155,6 +1155,18 @@ select column1, column5 from arrays_values_without_nulls; [21, 22, 23, 24, 25, 26, 27, 28, 29, 30] [6, 7] [31, 32, 33, 34, 35, 26, 37, 38, 39, 40] [8, 9] +# make array with arrays of different types +query ? +select make_array(make_array(1), arrow_cast(make_array(-1), 'LargeList(Int8)')) +---- +[[1], [-1]] + +query T +select arrow_typeof(make_array(make_array(1), arrow_cast(make_array(-1), 'LargeList(Int8)'))); +---- +List(Field { name: "item", data_type: LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + + query ??? select make_array(column1), make_array(column1, column5), diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index fb7afdda2ea82..b5e82f613a462 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -761,3 +761,19 @@ SELECT NULL WHERE FALSE; ---- 0.5 1 + +# Test Union of List Types. Issue: https://github.com/apache/datafusion/issues/12291 +query error DataFusion error: type_coercion\ncaused by\nError during planning: Incompatible inputs for Union: Previous inputs were of type List(.*), but got incompatible type List(.*) on column 'x' +SELECT make_array(2) x UNION ALL SELECT make_array(now()) x; + +query ? +select make_array(arrow_cast(2, 'UInt8')) x UNION ALL SELECT make_array(arrow_cast(-2, 'Int8')) x; +---- +[-2] +[2] + +query ? +select make_array(make_array(1)) x UNION ALL SELECT make_array(arrow_cast(make_array(-1), 'LargeList(Int8)')) x; +---- +[[-1]] +[[1]] From a2811fc85d469c879e3d4db6ceb3fa13fbf263be Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 21 Nov 2024 00:49:55 -0500 Subject: [PATCH 42/45] Update arrow/parquet to arrow/parquet `53.3.0` (#13508) * Update arrow/parquet to arrow 53.3.0 * Update Cargo.lock * fix ci test Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 Co-authored-by: jayzhan211 --- Cargo.toml | 18 ++-- datafusion-cli/Cargo.lock | 100 +++++++++--------- .../test_files/string/string_literal.slt | 90 ++++++++-------- .../test_files/string/string_view.slt | 13 ++- 4 files changed, 116 insertions(+), 105 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0011539156326..e947afff8f4fc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -74,22 +74,22 @@ version = "43.0.0" ahash = { version = "0.8", default-features = false, features = [ "runtime-rng", ] } -arrow = { version = "53.2.0", features = [ +arrow = { version = "53.3.0", features = [ "prettyprint", ] } -arrow-array = { version = "53.2.0", default-features = false, features = [ +arrow-array = { version = "53.3.0", default-features = false, features = [ "chrono-tz", ] } -arrow-buffer = { version = "53.2.0", default-features = false } -arrow-flight = { version = "53.2.0", features = [ +arrow-buffer = { version = "53.3.0", default-features = false } +arrow-flight = { version = "53.3.0", features = [ "flight-sql-experimental", ] } -arrow-ipc = { version = "53.2.0", default-features = false, features = [ +arrow-ipc = { version = "53.3.0", default-features = false, features = [ "lz4", ] } -arrow-ord = { version = "53.2.0", default-features = false } -arrow-schema = { version = "53.2.0", default-features = false } -arrow-string = { version = "53.2.0", default-features = false } +arrow-ord = { version = "53.3.0", default-features = false } +arrow-schema = { version = "53.3.0", default-features = false } +arrow-string = { version = "53.3.0", default-features = false } async-trait = "0.1.73" bigdecimal = "=0.4.1" bytes = "1.4" @@ -131,7 +131,7 @@ log = "^0.4" num_cpus = "1.13.0" object_store = { version = "0.11.0", default-features = false } parking_lot = "0.12" -parquet = { version = "53.2.0", default-features = false, features = [ +parquet = { version = "53.3.0", default-features = false, features = [ "arrow", "async", "object_store", diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index c5576b7e7d444..8afb096df55f1 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -173,9 +173,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4caf25cdc4a985f91df42ed9e9308e1adbcd341a31a72605c697033fcef163e3" +checksum = "c91839b07e474b3995035fd8ac33ee54f9c9ccbbb1ea33d9909c71bffdf1259d" dependencies = [ "arrow-arith", "arrow-array", @@ -194,9 +194,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91f2dfd1a7ec0aca967dfaa616096aec49779adc8eccec005e2f5e4111b1192a" +checksum = "855c57c4efd26722b044dcd3e348252560e3e0333087fb9f6479dc0bf744054f" dependencies = [ "arrow-array", "arrow-buffer", @@ -209,9 +209,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d39387ca628be747394890a6e47f138ceac1aa912eab64f02519fed24b637af8" +checksum = "bd03279cea46569acf9295f6224fbc370c5df184b4d2ecfe97ccb131d5615a7f" dependencies = [ "ahash", "arrow-buffer", @@ -220,15 +220,15 @@ dependencies = [ "chrono", "chrono-tz", "half", - "hashbrown 0.14.5", + "hashbrown 0.15.1", "num", ] [[package]] name = "arrow-buffer" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e51e05228852ffe3eb391ce7178a0f97d2cf80cc6ef91d3c4a6b3cb688049ec" +checksum = "9e4a9b9b1d6d7117f6138e13bc4dd5daa7f94e671b70e8c9c4dc37b4f5ecfc16" dependencies = [ "bytes", "half", @@ -237,9 +237,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d09aea56ec9fa267f3f3f6cdab67d8a9974cbba90b3aa38c8fe9d0bb071bd8c1" +checksum = "bc70e39916e60c5b7af7a8e2719e3ae589326039e1e863675a008bee5ffe90fd" dependencies = [ "arrow-array", "arrow-buffer", @@ -258,9 +258,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c07b5232be87d115fde73e32f2ca7f1b353bff1b44ac422d3c6fc6ae38f11f0d" +checksum = "789b2af43c1049b03a8d088ff6b2257cdcea1756cd76b174b1f2600356771b97" dependencies = [ "arrow-array", "arrow-buffer", @@ -277,9 +277,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b98ae0af50890b494cebd7d6b04b35e896205c1d1df7b29a6272c5d0d0249ef5" +checksum = "e4e75edf21ffd53744a9b8e3ed11101f610e7ceb1a29860432824f1834a1f623" dependencies = [ "arrow-buffer", "arrow-schema", @@ -289,9 +289,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ed91bdeaff5a1c00d28d8f73466bcb64d32bbd7093b5a30156b4b9f4dba3eee" +checksum = "d186a909dece9160bf8312f5124d797884f608ef5435a36d9d608e0b2a9bcbf8" dependencies = [ "arrow-array", "arrow-buffer", @@ -304,9 +304,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0471f51260a5309307e5d409c9dc70aede1cd9cf1d4ff0f0a1e8e1a2dd0e0d3c" +checksum = "b66ff2fedc1222942d0bd2fd391cb14a85baa3857be95c9373179bd616753b85" dependencies = [ "arrow-array", "arrow-buffer", @@ -324,9 +324,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2883d7035e0b600fb4c30ce1e50e66e53d8656aa729f2bfa4b51d359cf3ded52" +checksum = "ece7b5bc1180e6d82d1a60e1688c199829e8842e38497563c3ab6ea813e527fd" dependencies = [ "arrow-array", "arrow-buffer", @@ -339,9 +339,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "552907e8e587a6fde4f8843fd7a27a576a260f65dab6c065741ea79f633fc5be" +checksum = "745c114c8f0e8ce211c83389270de6fbe96a9088a7b32c2a041258a443fe83ff" dependencies = [ "ahash", "arrow-array", @@ -353,15 +353,15 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "539ada65246b949bd99ffa0881a9a15a4a529448af1a07a9838dd78617dafab1" +checksum = "b95513080e728e4cec37f1ff5af4f12c9688d47795d17cda80b6ec2cf74d4678" [[package]] name = "arrow-select" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6259e566b752da6dceab91766ed8b2e67bf6270eb9ad8a6e07a33c1bede2b125" +checksum = "8e415279094ea70323c032c6e739c48ad8d80e78a09bef7117b8718ad5bf3722" dependencies = [ "ahash", "arrow-array", @@ -373,9 +373,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3179ccbd18ebf04277a095ba7321b93fd1f774f18816bd5f6b3ce2f594edb6c" +checksum = "11d956cae7002eb8d83a27dbd34daaea1cf5b75852f0b84deb4d93a276e92bbf" dependencies = [ "arrow-array", "arrow-buffer", @@ -1158,9 +1158,9 @@ dependencies = [ [[package]] name = "ctor" -version = "0.2.8" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edb49164822f3ee45b17acd4a208cfc1251410cf0cad9a833234c9890774dd9f" +checksum = "32a2785755761f3ddc1492979ce1e48d2c00d09311c39e4466429188f3dd6501" dependencies = [ "quote", "syn", @@ -1934,9 +1934,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.6" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205" +checksum = "ccae279728d634d083c00f6099cb58f01cc99c145b84b8be2f6c74618d79922e" dependencies = [ "atomic-waker", "bytes", @@ -2120,14 +2120,14 @@ dependencies = [ [[package]] name = "hyper" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbbff0a806a4728c99295b254c8838933b5b082d75e3cb70c8dab21fdfbcfa9a" +checksum = "97818827ef4f364230e16705d4706e2897df2bb60617d6ca15d598025a3c481f" dependencies = [ "bytes", "futures-channel", "futures-util", - "h2 0.4.6", + "h2 0.4.7", "http 1.1.0", "http-body 1.0.1", "httparse", @@ -2162,7 +2162,7 @@ checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" dependencies = [ "futures-util", "http 1.1.0", - "hyper 1.5.0", + "hyper 1.5.1", "hyper-util", "rustls 0.23.17", "rustls-native-certs 0.8.0", @@ -2183,7 +2183,7 @@ dependencies = [ "futures-util", "http 1.1.0", "http-body 1.0.1", - "hyper 1.5.0", + "hyper 1.5.1", "pin-project-lite", "socket2", "tokio", @@ -2392,9 +2392,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.11" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +checksum = "540654e97a3f4470a492cd30ff187bc95d89557a903a2bbf112e2fae98104ef2" [[package]] name = "jobserver" @@ -2778,7 +2778,7 @@ dependencies = [ "chrono", "futures", "humantime", - "hyper 1.5.0", + "hyper 1.5.1", "itertools", "md-5", "parking_lot", @@ -2855,9 +2855,9 @@ dependencies = [ [[package]] name = "parquet" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dea02606ba6f5e856561d8d507dba8bac060aefca2a6c0f1aa1d361fed91ff3e" +checksum = "2b449890367085eb65d7d3321540abc3d7babbd179ce31df0016e90719114191" dependencies = [ "ahash", "arrow-array", @@ -2874,7 +2874,7 @@ dependencies = [ "flate2", "futures", "half", - "hashbrown 0.14.5", + "hashbrown 0.15.1", "lz4_flex", "num", "num-bigint", @@ -3256,11 +3256,11 @@ dependencies = [ "bytes", "futures-core", "futures-util", - "h2 0.4.6", + "h2 0.4.7", "http 1.1.0", "http-body 1.0.1", "http-body-util", - "hyper 1.5.0", + "hyper 1.5.1", "hyper-rustls 0.27.3", "hyper-util", "ipnet", @@ -3824,9 +3824,9 @@ dependencies = [ [[package]] name = "sync_wrapper" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" dependencies = [ "futures-core", ] @@ -4137,9 +4137,9 @@ checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicode-ident" -version = "1.0.13" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" +checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" [[package]] name = "unicode-segmentation" diff --git a/datafusion/sqllogictest/test_files/string/string_literal.slt b/datafusion/sqllogictest/test_files/string/string_literal.slt index 493da64063bca..145081f91a30d 100644 --- a/datafusion/sqllogictest/test_files/string/string_literal.slt +++ b/datafusion/sqllogictest/test_files/string/string_literal.slt @@ -901,7 +901,7 @@ SELECT '\' LIKE '\\', '\\' LIKE '\\' ---- -false true false true false false true +true false false true false true false # if "%%" in the pattern was simplified to "%", the pattern semantics would change query BBBBB @@ -1002,7 +1002,7 @@ NULL \%abc NULL \ NULL NULL \ (empty) false \ \ true -\ \\ false +\ \\ true \ \\\ false \ \\\\ false \ a false @@ -1010,10 +1010,10 @@ NULL \%abc NULL \ \\a false \ % true \ \% false -\ \\% false +\ \\% true \ %% true \ \%% false -\ \\%% false +\ \\%% true \ _ true \ \_ false \ \\_ false @@ -1028,21 +1028,21 @@ NULL \%abc NULL \\ NULL NULL \\ (empty) false \\ \ false -\\ \\ true -\\ \\\ false -\\ \\\\ false +\\ \\ false +\\ \\\ true +\\ \\\\ true \\ a false \\ \a false \\ \\a false \\ % true \\ \% false -\\ \\% false +\\ \\% true \\ %% true \\ \%% false -\\ \\%% false +\\ \\%% true \\ _ false \\ \_ false -\\ \\_ false +\\ \\_ true \\ __ true \\ \__ false \\ \\__ false @@ -1055,23 +1055,23 @@ NULL \%abc NULL \\\ (empty) false \\\ \ false \\\ \\ false -\\\ \\\ true +\\\ \\\ false \\\ \\\\ false \\\ a false \\\ \a false \\\ \\a false \\\ % true \\\ \% false -\\\ \\% false +\\\ \\% true \\\ %% true \\\ \%% false -\\\ \\%% false +\\\ \\%% true \\\ _ false \\\ \_ false \\\ \\_ false \\\ __ false \\\ \__ false -\\\ \\__ false +\\\ \\__ true \\\ abc false \\\ a_c false \\\ a\_c false @@ -1082,16 +1082,16 @@ NULL \%abc NULL \\\\ \ false \\\\ \\ false \\\\ \\\ false -\\\\ \\\\ true +\\\\ \\\\ false \\\\ a false \\\\ \a false \\\\ \\a false \\\\ % true \\\\ \% false -\\\\ \\% false +\\\\ \\% true \\\\ %% true \\\\ \%% false -\\\\ \\%% false +\\\\ \\%% true \\\\ _ false \\\\ \_ false \\\\ \\_ false @@ -1110,7 +1110,7 @@ a \\ false a \\\ false a \\\\ false a a true -a \a false +a \a true a \\a false a % true a \% false @@ -1136,17 +1136,17 @@ a \%abc false \a \\\ false \a \\\\ false \a a false -\a \a true -\a \\a false +\a \a false +\a \\a true \a % true \a \% false -\a \\% false +\a \\% true \a %% true \a \%% false -\a \\%% false +\a \\%% true \a _ false \a \_ false -\a \\_ false +\a \\_ true \a __ true \a \__ false \a \\__ false @@ -1163,19 +1163,19 @@ a \%abc false \\a \\\\ false \\a a false \\a \a false -\\a \\a true +\\a \\a false \\a % true \\a \% false -\\a \\% false +\\a \\% true \\a %% true \\a \%% false -\\a \\%% false +\\a \\%% true \\a _ false \\a \_ false \\a \\_ false \\a __ false \\a \__ false -\\a \\__ false +\\a \\__ true \\a abc false \\a a_c false \\a a\_c false @@ -1224,7 +1224,7 @@ a \%abc false \% \\%% true \% _ false \% \_ false -\% \\_ false +\% \\_ true \% __ true \% \__ false \% \\__ false @@ -1244,16 +1244,16 @@ a \%abc false \\% \\a false \\% % true \\% \% false -\\% \\% false +\\% \\% true \\% %% true \\% \%% false -\\% \\%% false +\\% \\%% true \\% _ false \\% \_ false \\% \\_ false \\% __ false \\% \__ false -\\% \\__ false +\\% \\__ true \\% abc false \\% a_c false \\% a\_c false @@ -1296,7 +1296,7 @@ a \%abc false \%% \\a false \%% % true \%% \% false -\%% \\% false +\%% \\% true \%% %% true \%% \%% false \%% \\%% true @@ -1305,7 +1305,7 @@ a \%abc false \%% \\_ false \%% __ false \%% \__ false -\%% \\__ false +\%% \\__ true \%% abc false \%% a_c false \%% a\_c false @@ -1322,10 +1322,10 @@ a \%abc false \\%% \\a false \\%% % true \\%% \% false -\\%% \\% false +\\%% \\% true \\%% %% true \\%% \%% false -\\%% \\%% false +\\%% \\%% true \\%% _ false \\%% \_ false \\%% \\_ false @@ -1374,10 +1374,10 @@ _ \%abc false \_ \\a false \_ % true \_ \% false -\_ \\% false +\_ \\% true \_ %% true \_ \%% false -\_ \\%% false +\_ \\%% true \_ _ false \_ \_ false \_ \\_ true @@ -1400,16 +1400,16 @@ _ \%abc false \\_ \\a false \\_ % true \\_ \% false -\\_ \\% false +\\_ \\% true \\_ %% true \\_ \%% false -\\_ \\%% false +\\_ \\%% true \\_ _ false \\_ \_ false \\_ \\_ false \\_ __ false \\_ \__ false -\\_ \\__ false +\\_ \\__ true \\_ abc false \\_ a_c false \\_ a\_c false @@ -1452,10 +1452,10 @@ __ \%abc false \__ \\a false \__ % true \__ \% false -\__ \\% false +\__ \\% true \__ %% true \__ \%% false -\__ \\%% false +\__ \\%% true \__ _ false \__ \_ false \__ \\_ false @@ -1478,10 +1478,10 @@ __ \%abc false \\__ \\a false \\__ % true \\__ \% false -\\__ \\% false +\\__ \\% true \\__ %% true \\__ \%% false -\\__ \\%% false +\\__ \\%% true \\__ _ false \\__ \_ false \\__ \\_ false @@ -1608,7 +1608,7 @@ a\_c \%abc false \%abc \\a false \%abc % true \%abc \% false -\%abc \\% false +\%abc \\% true \%abc %% true \%abc \%% false \%abc \\%% true diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt index 5a08f3f5447a5..aa41cbb8119ec 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -39,8 +39,19 @@ drop table test_source # TODO: Revisit this issue after upgrading to the arrow-rs version that includes apache/arrow-rs#6671. # see issue https://github.com/apache/datafusion/issues/13329 -query error DataFusion error: Arrow error: Compute error: bit_length not supported for Utf8View +query IIII select bit_length(ascii_1), bit_length(ascii_2), bit_length(unicode_1), bit_length(unicode_2) from test_basic_operator; +---- +48 8 144 32 +72 72 176 176 +56 8 240 64 +88 88 104 256 +56 24 216 288 +0 8 0 0 +8 16 0 0 +8 16 0 0 +NULL 8 NULL NULL +NULL 8 NULL 32 # # common test for string-like functions and operators From 9fb7aee95c5fcf177609963cedadf443ba6fe1b7 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 21 Nov 2024 01:05:57 -0500 Subject: [PATCH 43/45] Minor: Add debug log message for creating GroupValuesRows (#13506) * Minor: Add debug log message for creating GroupValuesRows * fmt --- datafusion/physical-plan/src/aggregates/group_values/row.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index de0ae2e07dd29..8e0f0a3d65070 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -27,6 +27,7 @@ use datafusion_common::Result; use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; use datafusion_expr::EmitTo; use hashbrown::raw::RawTable; +use log::debug; use std::mem::size_of; use std::sync::Arc; @@ -80,6 +81,9 @@ pub struct GroupValuesRows { impl GroupValuesRows { pub fn try_new(schema: SchemaRef) -> Result { + // Print a debugging message, so it is clear when the (slower) fallback + // GroupValuesRows is used. + debug!("Creating GroupValuesRows for schema: {}", schema); let row_converter = RowConverter::new( schema .fields() From e7d9504fd9d12cc8242b0b5d4b92c3fad4fb0e97 Mon Sep 17 00:00:00 2001 From: delamarch3 <68732277+delamarch3@users.noreply.github.com> Date: Thu, 21 Nov 2024 14:02:33 +0000 Subject: [PATCH 44/45] Unparse struct to sql (#13493) * unparse struct to sql * add roundtrip statement test for named_struct * quote keys if needed * add roundtrip statement test for get_field * improve error messages * fmt * fmt * match string literals only Co-authored-by: Jax Liu --------- Co-authored-by: Jax Liu --- datafusion/sql/src/unparser/expr.rs | 61 ++++++++++++++++++++++- datafusion/sql/tests/cases/plan_to_sql.rs | 4 +- 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index f1f28258f9bd6..ae2607de00a21 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -462,7 +462,9 @@ impl Unparser<'_> { match func_name { "make_array" => self.make_array_to_sql(args), "array_element" => self.array_element_to_sql(args), - // TODO: support for the construct and access functions of the `map` and `struct` types + "named_struct" => self.named_struct_to_sql(args), + "get_field" => self.get_field_to_sql(args), + // TODO: support for the construct and access functions of the `map` type _ => self.scalar_function_to_sql_internal(func_name, args), } } @@ -514,6 +516,57 @@ impl Unparser<'_> { }) } + fn named_struct_to_sql(&self, args: &[Expr]) -> Result { + if args.len() % 2 != 0 { + return internal_err!("named_struct must have an even number of arguments"); + } + + let args = args + .chunks_exact(2) + .map(|chunk| { + let key = match &chunk[0] { + Expr::Literal(ScalarValue::Utf8(Some(s))) => self.new_ident_quoted_if_needs(s.to_string()), + _ => return internal_err!("named_struct expects even arguments to be strings, but received: {:?}", &chunk[0]) + }; + + Ok(ast::DictionaryField { + key, + value: Box::new(self.expr_to_sql(&chunk[1])?), + }) + }) + .collect::>>()?; + + Ok(ast::Expr::Dictionary(args)) + } + + fn get_field_to_sql(&self, args: &[Expr]) -> Result { + if args.len() != 2 { + return internal_err!("get_field must have exactly 2 arguments"); + } + + let mut id = match &args[0] { + Expr::Column(col) => match self.col_to_sql(col)? { + ast::Expr::Identifier(ident) => vec![ident], + ast::Expr::CompoundIdentifier(idents) => idents, + other => return internal_err!("expected col_to_sql to return an Identifier or CompoundIdentifier, but received: {:?}", other), + }, + _ => return internal_err!("get_field expects first argument to be column, but received: {:?}", &args[0]), + }; + + let field = match &args[1] { + Expr::Literal(lit) => self.new_ident_quoted_if_needs(lit.to_string()), + _ => { + return internal_err!( + "get_field expects second argument to be a string, but received: {:?}", + &args[0] + ) + } + }; + id.push(field); + + Ok(ast::Expr::CompoundIdentifier(id)) + } + pub fn sort_to_sql(&self, sort: &Sort) -> Result { let Sort { expr, @@ -1524,6 +1577,7 @@ mod tests { Signature, Volatility, WindowFrame, WindowFunctionDefinition, }; use datafusion_expr::{interval_month_day_nano_lit, ExprFunctionExt}; + use datafusion_functions::expr_fn::{get_field, named_struct}; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::expr_fn::sum; use datafusion_functions_nested::expr_fn::{array_element, make_array}; @@ -1937,6 +1991,11 @@ mod tests { array_element(make_array(vec![lit(1), lit(2), lit(3)]), lit(1)), "[1, 2, 3][1]", ), + ( + named_struct(vec![lit("a"), lit("1"), lit("b"), lit(2)]), + "{a: '1', b: 2}", + ), + (get_field(col("a.b"), "c"), "a.b.c"), ]; for (expr, expected) in tests { diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index f9d97cdc74af9..58d99549de319 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -188,7 +188,9 @@ fn roundtrip_statement() -> Result<()> { "SELECT ARRAY[1, 2, 3][1]", "SELECT [1, 2, 3]", "SELECT [1, 2, 3][1]", - "SELECT left[1] FROM array" + "SELECT left[1] FROM array", + "SELECT {a:1, b:2}", + "SELECT s.a FROM (SELECT {a:1, b:2} AS s)" ]; // For each test sql string, we transform as follows: From edbd93aacf0b2397cbb1051b1da261fa008c23dd Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Thu, 21 Nov 2024 16:08:42 +0000 Subject: [PATCH 45/45] Add `ScalarUDFImpl::invoke_with_args` to support passing the return type created for the udf instance (#13290) * Added support for `ScalarUDFImpl::invoke_with_return_type` where the invoke is passed the return type created for the udf instance * Do not yet deprecate invoke_batch, add docs to invoke_with_args * add ticket reference --------- Co-authored-by: Andrew Lamb --- datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/udf.rs | 98 +++++++++++-------- datafusion/functions/benches/random.rs | 2 + datafusion/functions/src/core/version.rs | 1 + .../functions/src/datetime/to_local_time.rs | 9 +- .../functions/src/datetime/to_timestamp.rs | 4 +- .../functions/src/datetime/to_unixtime.rs | 1 + datafusion/functions/src/math/log.rs | 20 ++-- datafusion/functions/src/math/power.rs | 4 +- datafusion/functions/src/math/signum.rs | 2 + datafusion/functions/src/regex/regexpcount.rs | 24 ++--- datafusion/functions/src/utils.rs | 7 +- .../physical-expr/src/scalar_function.rs | 8 +- 13 files changed, 107 insertions(+), 75 deletions(-) diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 27b2d71b1f425..d8b829f27e7d4 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -92,7 +92,7 @@ pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::{ aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs, }; -pub use udf::{scalar_doc_sections, ScalarUDF, ScalarUDFImpl}; +pub use udf::{scalar_doc_sections, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; pub use udf_docs::{DocSection, Documentation, DocumentationBuilder}; pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 1a5d50477b1c8..57b8d9c6b02e8 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -203,10 +203,7 @@ impl ScalarUDF { self.inner.simplify(args, info) } - /// Invoke the function on `args`, returning the appropriate result. - /// - /// See [`ScalarUDFImpl::invoke`] for more details. - #[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")] + #[deprecated(since = "42.1.0", note = "Use `invoke_with_args` instead")] pub fn invoke(&self, args: &[ColumnarValue]) -> Result { #[allow(deprecated)] self.inner.invoke(args) @@ -216,20 +213,27 @@ impl ScalarUDF { self.inner.is_nullable(args, schema) } - /// Invoke the function with `args` and number of rows, returning the appropriate result. - /// - /// See [`ScalarUDFImpl::invoke_batch`] for more details. + #[deprecated(since = "43.0.0", note = "Use `invoke_with_args` instead")] pub fn invoke_batch( &self, args: &[ColumnarValue], number_rows: usize, ) -> Result { + #[allow(deprecated)] self.inner.invoke_batch(args, number_rows) } + /// Invoke the function on `args`, returning the appropriate result. + /// + /// See [`ScalarUDFImpl::invoke_with_args`] for details. + pub fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + self.inner.invoke_with_args(args) + } + /// Invoke the function without `args` but number of rows, returning the appropriate result. /// - /// See [`ScalarUDFImpl::invoke_no_args`] for more details. + /// Note: This method is deprecated and will be removed in future releases. + /// User defined functions should implement [`Self::invoke_with_args`] instead. #[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")] pub fn invoke_no_args(&self, number_rows: usize) -> Result { #[allow(deprecated)] @@ -324,7 +328,17 @@ where } } -/// Trait for implementing [`ScalarUDF`]. +pub struct ScalarFunctionArgs<'a> { + // The evaluated arguments to the function + pub args: &'a [ColumnarValue], + // The number of rows in record batch being evaluated + pub number_rows: usize, + // The return type of the scalar function returned (from `return_type` or `return_type_from_exprs`) + // when creating the physical expression from the logical expression + pub return_type: &'a DataType, +} + +/// Trait for implementing user defined scalar functions. /// /// This trait exposes the full API for implementing user defined functions and /// can be used to implement any function. @@ -332,18 +346,19 @@ where /// See [`advanced_udf.rs`] for a full example with complete implementation and /// [`ScalarUDF`] for other available options. /// -/// /// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs +/// /// # Basic Example /// ``` /// # use std::any::Any; /// # use std::sync::OnceLock; /// # use arrow::datatypes::DataType; /// # use datafusion_common::{DataFusionError, plan_err, Result}; -/// # use datafusion_expr::{col, ColumnarValue, Documentation, Signature, Volatility}; +/// # use datafusion_expr::{col, ColumnarValue, Documentation, ScalarFunctionArgs, Signature, Volatility}; /// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF}; /// # use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; /// +/// /// This struct for a simple UDF that adds one to an int32 /// #[derive(Debug)] /// struct AddOne { /// signature: Signature, @@ -356,7 +371,7 @@ where /// } /// } /// } -/// +/// /// static DOCUMENTATION: OnceLock = OnceLock::new(); /// /// fn get_doc() -> &'static Documentation { @@ -383,7 +398,9 @@ where /// Ok(DataType::Int32) /// } /// // The actual implementation would add one to the argument -/// fn invoke(&self, args: &[ColumnarValue]) -> Result { unimplemented!() } +/// fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { +/// unimplemented!() +/// } /// fn documentation(&self) -> Option<&Documentation> { /// Some(get_doc()) /// } @@ -479,24 +496,9 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// Invoke the function on `args`, returning the appropriate result /// - /// The function will be invoked passed with the slice of [`ColumnarValue`] - /// (either scalar or array). - /// - /// If the function does not take any arguments, please use [invoke_no_args] - /// instead and return [not_impl_err] for this function. - /// - /// - /// # Performance - /// - /// For the best performance, the implementations of `invoke` should handle - /// the common case when one or more of their arguments are constant values - /// (aka [`ColumnarValue::Scalar`]). - /// - /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments - /// to arrays, which will likely be simpler code, but be slower. - /// - /// [invoke_no_args]: ScalarUDFImpl::invoke_no_args - #[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")] + /// Note: This method is deprecated and will be removed in future releases. + /// User defined functions should implement [`Self::invoke_with_args`] instead. + #[deprecated(since = "42.1.0", note = "Use `invoke_with_args` instead")] fn invoke(&self, _args: &[ColumnarValue]) -> Result { not_impl_err!( "Function {} does not implement invoke but called", @@ -507,17 +509,12 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// Invoke the function with `args` and the number of rows, /// returning the appropriate result. /// - /// The function will be invoked with the slice of [`ColumnarValue`] - /// (either scalar or array). - /// - /// # Performance + /// Note: See notes on [`Self::invoke_with_args`] /// - /// For the best performance, the implementations should handle the common case - /// when one or more of their arguments are constant values (aka - /// [`ColumnarValue::Scalar`]). + /// Note: This method is deprecated and will be removed in future releases. + /// User defined functions should implement [`Self::invoke_with_args`] instead. /// - /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments - /// to arrays, which will likely be simpler code, but be slower. + /// See for more details. fn invoke_batch( &self, args: &[ColumnarValue], @@ -537,9 +534,27 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { } } + /// Invoke the function returning the appropriate result. + /// + /// # Performance + /// + /// For the best performance, the implementations should handle the common case + /// when one or more of their arguments are constant values (aka + /// [`ColumnarValue::Scalar`]). + /// + /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments + /// to arrays, which will likely be simpler code, but be slower. + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + #[allow(deprecated)] + self.invoke_batch(args.args, args.number_rows) + } + /// Invoke the function without `args`, instead the number of rows are provided, /// returning the appropriate result. - #[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")] + /// + /// Note: This method is deprecated and will be removed in future releases. + /// User defined functions should implement [`Self::invoke_with_args`] instead. + #[deprecated(since = "42.1.0", note = "Use `invoke_with_args` instead")] fn invoke_no_args(&self, _number_rows: usize) -> Result { not_impl_err!( "Function {} does not implement invoke_no_args but called", @@ -767,6 +782,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { args: &[ColumnarValue], number_rows: usize, ) -> Result { + #[allow(deprecated)] self.inner.invoke_batch(args, number_rows) } diff --git a/datafusion/functions/benches/random.rs b/datafusion/functions/benches/random.rs index 5df5d9c7dee22..bc20e0ff11c1f 100644 --- a/datafusion/functions/benches/random.rs +++ b/datafusion/functions/benches/random.rs @@ -29,6 +29,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("random_1M_rows_batch_8192", |b| { b.iter(|| { for _ in 0..iterations { + #[allow(deprecated)] // TODO: migrate to invoke_with_args black_box(random_func.invoke_batch(&[], 8192).unwrap()); } }) @@ -39,6 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("random_1M_rows_batch_128", |b| { b.iter(|| { for _ in 0..iterations_128 { + #[allow(deprecated)] // TODO: migrate to invoke_with_args black_box(random_func.invoke_batch(&[], 128).unwrap()); } }) diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs index 36cf07e9e5da2..eac0aa38f0583 100644 --- a/datafusion/functions/src/core/version.rs +++ b/datafusion/functions/src/core/version.rs @@ -121,6 +121,7 @@ mod test { #[tokio::test] async fn test_version_udf() { let version_udf = ScalarUDF::from(VersionFunc::new()); + #[allow(deprecated)] // TODO: migrate to invoke_with_args let version = version_udf.invoke_batch(&[], 1).unwrap(); if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(version))) = version { diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index fef1eb9a60c82..5048b8fd47ec6 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -431,7 +431,7 @@ mod tests { use arrow::datatypes::{DataType, TimeUnit}; use chrono::NaiveDateTime; use datafusion_common::ScalarValue; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use super::{adjust_to_local_time, ToLocalTimeFunc}; @@ -558,7 +558,11 @@ mod tests { fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) { let res = ToLocalTimeFunc::new() - .invoke_batch(&[ColumnarValue::Scalar(input)], 1) + .invoke_with_args(ScalarFunctionArgs { + args: &[ColumnarValue::Scalar(input)], + number_rows: 1, + return_type: &expected.data_type(), + }) .unwrap(); match res { ColumnarValue::Scalar(res) => { @@ -617,6 +621,7 @@ mod tests { .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) .collect::(); let batch_size = input.len(); + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = ToLocalTimeFunc::new() .invoke_batch(&[ColumnarValue::Array(Arc::new(input))], batch_size) .unwrap(); diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index f15fad701c554..78a7bf505dac1 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -1008,7 +1008,7 @@ mod tests { for array in arrays { let rt = udf.return_type(&[array.data_type()]).unwrap(); assert!(matches!(rt, Timestamp(_, Some(_)))); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let res = udf .invoke_batch(&[array.clone()], 1) .expect("that to_timestamp parsed values without error"); @@ -1051,7 +1051,7 @@ mod tests { for array in arrays { let rt = udf.return_type(&[array.data_type()]).unwrap(); assert!(matches!(rt, Timestamp(_, None))); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let res = udf .invoke_batch(&[array.clone()], 1) .expect("that to_timestamp parsed values without error"); diff --git a/datafusion/functions/src/datetime/to_unixtime.rs b/datafusion/functions/src/datetime/to_unixtime.rs index dd90ce6a6c968..c291596c25200 100644 --- a/datafusion/functions/src/datetime/to_unixtime.rs +++ b/datafusion/functions/src/datetime/to_unixtime.rs @@ -83,6 +83,7 @@ impl ScalarUDFImpl for ToUnixtimeFunc { DataType::Date64 | DataType::Date32 | DataType::Timestamp(_, None) => args[0] .cast_to(&DataType::Timestamp(TimeUnit::Second, None), None)? .cast_to(&DataType::Int64, None), + #[allow(deprecated)] // TODO: migrate to invoke_with_args DataType::Utf8 => ToTimestampSecondsFunc::new() .invoke_batch(args, batch_size)? .cast_to(&DataType::Int64, None), diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 9110f9f532d84..14b6dc3e054ed 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -277,7 +277,7 @@ mod tests { ]))), // num ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))), ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let _ = LogFunc::new().invoke_batch(&args, 4); } @@ -286,7 +286,7 @@ mod tests { let args = [ ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = LogFunc::new().invoke_batch(&args, 1); result.expect_err("expected error"); } @@ -296,7 +296,7 @@ mod tests { let args = [ ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = LogFunc::new() .invoke_batch(&args, 1) .expect("failed to initialize function log"); @@ -320,7 +320,7 @@ mod tests { let args = [ ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = LogFunc::new() .invoke_batch(&args, 1) .expect("failed to initialize function log"); @@ -345,7 +345,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = LogFunc::new() .invoke_batch(&args, 1) .expect("failed to initialize function log"); @@ -370,7 +370,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = LogFunc::new() .invoke_batch(&args, 1) .expect("failed to initialize function log"); @@ -396,7 +396,7 @@ mod tests { 10.0, 100.0, 1000.0, 10000.0, ]))), // num ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = LogFunc::new() .invoke_batch(&args, 4) .expect("failed to initialize function log"); @@ -425,7 +425,7 @@ mod tests { 10.0, 100.0, 1000.0, 10000.0, ]))), // num ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = LogFunc::new() .invoke_batch(&args, 4) .expect("failed to initialize function log"); @@ -455,7 +455,7 @@ mod tests { 8.0, 4.0, 81.0, 625.0, ]))), // num ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = LogFunc::new() .invoke_batch(&args, 4) .expect("failed to initialize function log"); @@ -485,7 +485,7 @@ mod tests { 8.0, 4.0, 81.0, 625.0, ]))), // num ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = LogFunc::new() .invoke_batch(&args, 4) .expect("failed to initialize function log"); diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index a24c613f52599..acf5f84df92b0 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -205,7 +205,7 @@ mod tests { ColumnarValue::Array(Arc::new(Float64Array::from(vec![2.0, 2.0, 3.0, 5.0]))), // base ColumnarValue::Array(Arc::new(Float64Array::from(vec![3.0, 2.0, 4.0, 4.0]))), // exponent ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = PowerFunc::new() .invoke_batch(&args, 4) .expect("failed to initialize function power"); @@ -232,7 +232,7 @@ mod tests { ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = PowerFunc::new() .invoke_batch(&args, 4) .expect("failed to initialize function power"); diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index 7f21297712c73..33ff630f309ff 100644 --- a/datafusion/functions/src/math/signum.rs +++ b/datafusion/functions/src/math/signum.rs @@ -167,6 +167,7 @@ mod test { f32::NEG_INFINITY, ])); let batch_size = array.len(); + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = SignumFunc::new() .invoke_batch(&[ColumnarValue::Array(array)], batch_size) .expect("failed to initialize function signum"); @@ -207,6 +208,7 @@ mod test { f64::NEG_INFINITY, ])); let batch_size = array.len(); + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = SignumFunc::new() .invoke_batch(&[ColumnarValue::Array(array)], batch_size) .expect("failed to initialize function signum"); diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index 8da154430fc55..819463795b7fd 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -655,7 +655,7 @@ mod tests { let v_sv = ScalarValue::Utf8(Some(v.to_string())); let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); let expected = expected.get(pos).cloned(); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], 1, @@ -670,7 +670,7 @@ mod tests { // largeutf8 let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], 1, @@ -685,7 +685,7 @@ mod tests { // utf8view let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], 1, @@ -711,7 +711,7 @@ mod tests { let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); let start_sv = ScalarValue::Int64(Some(start)); let expected = expected.get(pos).cloned(); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ ColumnarValue::Scalar(v_sv), @@ -730,7 +730,7 @@ mod tests { // largeutf8 let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ ColumnarValue::Scalar(v_sv), @@ -749,7 +749,7 @@ mod tests { // utf8view let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ ColumnarValue::Scalar(v_sv), @@ -781,7 +781,7 @@ mod tests { let start_sv = ScalarValue::Int64(Some(start)); let flags_sv = ScalarValue::Utf8(Some(flags.to_string())); let expected = expected.get(pos).cloned(); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ ColumnarValue::Scalar(v_sv), @@ -802,7 +802,7 @@ mod tests { let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); let flags_sv = ScalarValue::LargeUtf8(Some(flags.to_string())); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ ColumnarValue::Scalar(v_sv), @@ -823,7 +823,7 @@ mod tests { let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); let flags_sv = ScalarValue::Utf8View(Some(flags.to_string())); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ ColumnarValue::Scalar(v_sv), @@ -905,7 +905,7 @@ mod tests { let start_sv = ScalarValue::Int64(Some(start)); let flags_sv = ScalarValue::Utf8(flags.get(pos).map(|f| f.to_string())); let expected = expected.get(pos).cloned(); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ ColumnarValue::Scalar(v_sv), @@ -926,7 +926,7 @@ mod tests { let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(regex.get(pos).map(|s| s.to_string())); let flags_sv = ScalarValue::LargeUtf8(flags.get(pos).map(|f| f.to_string())); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ ColumnarValue::Scalar(v_sv), @@ -947,7 +947,7 @@ mod tests { let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(regex.get(pos).map(|s| s.to_string())); let flags_sv = ScalarValue::Utf8View(flags.get(pos).map(|f| f.to_string())); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ ColumnarValue::Scalar(v_sv), diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 87180cb77de72..8b473500416b5 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -146,9 +146,10 @@ pub mod test { match expected { Ok(expected) => { assert_eq!(return_type.is_ok(), true); - assert_eq!(return_type.unwrap(), $EXPECTED_DATA_TYPE); + let return_type = return_type.unwrap(); + assert_eq!(return_type, $EXPECTED_DATA_TYPE); - let result = func.invoke_batch($ARGS, cardinality); + let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type}); assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err()); let result = result.unwrap().clone().into_array(cardinality).expect("Failed to convert to array"); @@ -169,7 +170,7 @@ pub mod test { } else { // invoke is expected error - cannot use .expect_err() due to Debug not being implemented - match func.invoke_batch($ARGS, cardinality) { + match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type.unwrap()}) { Ok(_) => assert!(false, "expected error"), Err(error) => { assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 9bf168e8a1998..74d0ecdadd328 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -43,7 +43,7 @@ use datafusion_common::{internal_err, DFSchema, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; -use datafusion_expr::{expr_vec_fmt, ColumnarValue, Expr, ScalarUDF}; +use datafusion_expr::{expr_vec_fmt, ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF}; /// Physical expression of a scalar function #[derive(Eq, PartialEq, Hash)] @@ -141,7 +141,11 @@ impl PhysicalExpr for ScalarFunctionExpr { .collect::>>()?; // evaluate the function - let output = self.fun.invoke_batch(&inputs, batch.num_rows())?; + let output = self.fun.invoke_with_args(ScalarFunctionArgs { + args: inputs.as_slice(), + number_rows: batch.num_rows(), + return_type: &self.return_type, + })?; if let ColumnarValue::Array(array) = &output { if array.len() != batch.num_rows() {