diff --git a/ballista/rust/core/src/error.rs b/ballista/rust/core/src/error.rs index b2c8d99ae9f9..e9ffcd8180eb 100644 --- a/ballista/rust/core/src/error.rs +++ b/ballista/rust/core/src/error.rs @@ -139,7 +139,7 @@ impl From for BallistaError { } impl Display for BallistaError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { match self { BallistaError::NotImplemented(ref desc) => { write!(f, "Not implemented: {}", desc) diff --git a/ballista/rust/core/src/execution_plans/distributed_query.rs b/ballista/rust/core/src/execution_plans/distributed_query.rs index bebc98f08cc4..619cc9bc925d 100644 --- a/ballista/rust/core/src/execution_plans/distributed_query.rs +++ b/ballista/rust/core/src/execution_plans/distributed_query.rs @@ -39,6 +39,7 @@ use datafusion::physical_plan::{ }; use async_trait::async_trait; +use datafusion::execution::runtime_env::RuntimeEnv; use futures::future; use futures::StreamExt; use log::{error, info}; @@ -99,7 +100,8 @@ impl ExecutionPlan for DistributedQueryExec { async fn execute( &self, partition: usize, - ) -> datafusion::error::Result { + _runtime: Arc, + ) -> Result { assert_eq!(0, partition); info!("Connecting to Ballista scheduler at {}", self.scheduler_url); diff --git a/ballista/rust/core/src/execution_plans/shuffle_reader.rs b/ballista/rust/core/src/execution_plans/shuffle_reader.rs index 6cdd8cc7665a..4d401eca03ff 100644 --- a/ballista/rust/core/src/execution_plans/shuffle_reader.rs +++ b/ballista/rust/core/src/execution_plans/shuffle_reader.rs @@ -28,11 +28,13 @@ use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef; use datafusion::arrow::error::Result as ArrowResult; use datafusion::arrow::record_batch::RecordBatch; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::physical_plan::metrics::{ ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, }; use datafusion::physical_plan::{ - DisplayFormatType, ExecutionPlan, Metric, Partitioning, Statistics, + DisplayFormatType, ExecutionPlan, Metric, Partitioning, SendableRecordBatchStream, + Statistics, }; use datafusion::{ error::{DataFusionError, Result}, @@ -100,7 +102,8 @@ impl ExecutionPlan for ShuffleReaderExec { async fn execute( &self, partition: usize, - ) -> Result>> { + _runtime: Arc, + ) -> Result { info!("ShuffleReaderExec::execute({})", partition); let fetch_time = diff --git a/ballista/rust/core/src/execution_plans/shuffle_writer.rs b/ballista/rust/core/src/execution_plans/shuffle_writer.rs index 6884720501fa..0962615d96f7 100644 --- a/ballista/rust/core/src/execution_plans/shuffle_writer.rs +++ b/ballista/rust/core/src/execution_plans/shuffle_writer.rs @@ -44,6 +44,8 @@ use datafusion::arrow::ipc::reader::FileReader; use datafusion::arrow::ipc::writer::FileWriter; use datafusion::arrow::record_batch::RecordBatch; use datafusion::error::{DataFusionError, Result}; +use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::physical_plan::common::IPCWriter; use datafusion::physical_plan::hash_utils::create_hashes; use datafusion::physical_plan::metrics::{ self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, @@ -51,7 +53,8 @@ use datafusion::physical_plan::metrics::{ use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::Partitioning::RoundRobinBatch; use datafusion::physical_plan::{ - DisplayFormatType, ExecutionPlan, Metric, Partitioning, RecordBatchStream, Statistics, + DisplayFormatType, ExecutionPlan, Metric, Partitioning, RecordBatchStream, + SendableRecordBatchStream, Statistics, }; use futures::StreamExt; use hashbrown::HashMap; @@ -139,10 +142,11 @@ impl ShuffleWriterExec { pub async fn execute_shuffle_write( &self, input_partition: usize, + runtime: Arc, ) -> Result> { let now = Instant::now(); - let mut stream = self.plan.execute(input_partition).await?; + let mut stream = self.plan.execute(input_partition, runtime).await?; let mut path = PathBuf::from(&self.work_dir); path.push(&self.job_id); @@ -197,7 +201,7 @@ impl ShuffleWriterExec { // we won't necessary produce output for every possible partition, so we // create writers on demand - let mut writers: Vec> = vec![]; + let mut writers: Vec> = vec![]; for _ in 0..num_output_partitions { writers.push(None); } @@ -265,7 +269,7 @@ impl ShuffleWriterExec { info!("Writing results to {}", path); let mut writer = - ShuffleWriter::new(path, stream.schema().as_ref())?; + IPCWriter::new(path, stream.schema().as_ref())?; writer.write(&output_batch)?; writers[output_partition] = Some(writer); @@ -350,9 +354,10 @@ impl ExecutionPlan for ShuffleWriterExec { async fn execute( &self, - input_partition: usize, - ) -> Result>> { - let part_loc = self.execute_shuffle_write(input_partition).await?; + partition: usize, + runtime: Arc, + ) -> Result { + let part_loc = self.execute_shuffle_write(partition, runtime).await?; // build metadata result batch let num_writers = part_loc.len(); @@ -432,55 +437,6 @@ fn result_schema() -> SchemaRef { ])) } -struct ShuffleWriter { - path: String, - writer: FileWriter, - num_batches: u64, - num_rows: u64, - num_bytes: u64, -} - -impl ShuffleWriter { - fn new(path: &str, schema: &Schema) -> Result { - let file = File::create(path) - .map_err(|e| { - BallistaError::General(format!( - "Failed to create partition file at {}: {:?}", - path, e - )) - }) - .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; - Ok(Self { - num_batches: 0, - num_rows: 0, - num_bytes: 0, - path: path.to_owned(), - writer: FileWriter::try_new(file, schema)?, - }) - } - - fn write(&mut self, batch: &RecordBatch) -> Result<()> { - self.writer.write(batch)?; - self.num_batches += 1; - self.num_rows += batch.num_rows() as u64; - let num_bytes: usize = batch - .columns() - .iter() - .map(|array| array.get_array_memory_size()) - .sum(); - self.num_bytes += num_bytes as u64; - Ok(()) - } - - fn finish(&mut self) -> Result<()> { - self.writer.finish().map_err(DataFusionError::ArrowError) - } - - fn path(&self) -> &str { - &self.path - } -} - #[cfg(test)] mod tests { use super::*; @@ -493,6 +449,8 @@ mod tests { #[tokio::test] async fn test() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); + let input_plan = Arc::new(CoalescePartitionsExec::new(create_input_plan()?)); let work_dir = TempDir::new()?; let query_stage = ShuffleWriterExec::try_new( @@ -502,7 +460,7 @@ mod tests { work_dir.into_path().to_str().unwrap().to_owned(), Some(Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 2)), )?; - let mut stream = query_stage.execute(0).await?; + let mut stream = query_stage.execute(0, runtime).await?; let batches = utils::collect_stream(&mut stream) .await .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; @@ -545,6 +503,8 @@ mod tests { #[tokio::test] async fn test_partitioned() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); + let input_plan = create_input_plan()?; let work_dir = TempDir::new()?; let query_stage = ShuffleWriterExec::try_new( @@ -554,7 +514,7 @@ mod tests { work_dir.into_path().to_str().unwrap().to_owned(), Some(Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 2)), )?; - let mut stream = query_stage.execute(0).await?; + let mut stream = query_stage.execute(0, runtime).await?; let batches = utils::collect_stream(&mut stream) .await .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; diff --git a/ballista/rust/core/src/execution_plans/unresolved_shuffle.rs b/ballista/rust/core/src/execution_plans/unresolved_shuffle.rs index 6290add4e2b4..6de8dbab0a11 100644 --- a/ballista/rust/core/src/execution_plans/unresolved_shuffle.rs +++ b/ballista/rust/core/src/execution_plans/unresolved_shuffle.rs @@ -23,8 +23,9 @@ use crate::serde::scheduler::PartitionLocation; use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::physical_plan::{ - DisplayFormatType, ExecutionPlan, Partitioning, Statistics, + DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; use datafusion::{ error::{DataFusionError, Result}, @@ -102,7 +103,8 @@ impl ExecutionPlan for UnresolvedShuffleExec { async fn execute( &self, _partition: usize, - ) -> Result>> { + _runtime: Arc, + ) -> Result { Err(DataFusionError::Plan( "Ballista UnresolvedShuffleExec does not support execution".to_owned(), )) diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 3c05957987bb..cad27b315645 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -41,6 +41,7 @@ use datafusion::datasource::PartitionedFile; use datafusion::execution::context::{ ExecutionConfig, ExecutionContextState, ExecutionProps, }; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_plan::{ window_frames::WindowFrame, DFSchema, Expr, JoinConstraint, JoinType, }; @@ -53,6 +54,7 @@ use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec use datafusion::physical_plan::hash_join::PartitionMode; use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion::physical_plan::planner::DefaultPhysicalPlanner; +use datafusion::physical_plan::sorts::sort::{SortExec, SortOptions}; use datafusion::physical_plan::window_functions::{ BuiltInWindowFunction, WindowFunction, }; @@ -72,7 +74,6 @@ use datafusion::physical_plan::{ limit::{GlobalLimitExec, LocalLimitExec}, projection::ProjectionExec, repartition::RepartitionExec, - sort::{SortExec, SortOptions}, Partitioning, }; use datafusion::physical_plan::{ @@ -626,6 +627,7 @@ impl TryFrom<&protobuf::PhysicalExprNode> for Arc { config: ExecutionConfig::new(), execution_props: ExecutionProps::new(), object_store_registry: Arc::new(ObjectStoreRegistry::new()), + runtime_env: Arc::new(RuntimeEnv::default()), }; let fun_expr = functions::create_physical_fun( diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs index aca8f6459d23..a4f2f2ff6868 100644 --- a/ballista/rust/core/src/serde/physical_plan/mod.rs +++ b/ballista/rust/core/src/serde/physical_plan/mod.rs @@ -22,6 +22,7 @@ pub mod to_proto; mod roundtrip_tests { use std::{convert::TryInto, sync::Arc}; + use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::{ arrow::{ compute::kernels::sort::SortOptions, @@ -36,7 +37,6 @@ mod roundtrip_tests { hash_aggregate::{AggregateMode, HashAggregateExec}, hash_join::{HashJoinExec, PartitionMode}, limit::{GlobalLimitExec, LocalLimitExec}, - sort::SortExec, AggregateExpr, ColumnarValue, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, }, diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs index 41484db57a7b..930f0757e202 100644 --- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs @@ -29,7 +29,7 @@ use std::{ use datafusion::physical_plan::hash_join::{HashJoinExec, PartitionMode}; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion::physical_plan::projection::ProjectionExec; -use datafusion::physical_plan::sort::SortExec; +use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::{cross_join::CrossJoinExec, ColumnStatistics}; use datafusion::physical_plan::{ expressions::{ diff --git a/ballista/rust/core/src/utils.rs b/ballista/rust/core/src/utils.rs index 2dfdb3d81181..bdb92c530f91 100644 --- a/ballista/rust/core/src/utils.rs +++ b/ballista/rust/core/src/utils.rs @@ -52,6 +52,7 @@ use datafusion::physical_optimizer::merge_exec::AddCoalescePartitionsExec; use datafusion::physical_optimizer::optimizer::PhysicalOptimizerRule; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion::physical_plan::common::batch_byte_size; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::{BinaryExpr, Column, Literal}; use datafusion::physical_plan::file_format::{CsvExec, ParquetExec}; @@ -59,7 +60,7 @@ use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::hash_aggregate::HashAggregateExec; use datafusion::physical_plan::hash_join::HashJoinExec; use datafusion::physical_plan::projection::ProjectionExec; -use datafusion::physical_plan::sort::SortExec; +use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::{ metrics, AggregateExpr, ExecutionPlan, Metric, PhysicalExpr, RecordBatchStream, }; @@ -88,11 +89,7 @@ pub async fn write_stream_to_disk( while let Some(result) = stream.next().await { let batch = result?; - let batch_size_bytes: usize = batch - .columns() - .iter() - .map(|array| array.get_array_memory_size()) - .sum(); + let batch_size_bytes: usize = batch_byte_size(&batch); num_batches += 1; num_rows += batch.num_rows(); num_bytes += batch_size_bytes; diff --git a/ballista/rust/executor/src/collect.rs b/ballista/rust/executor/src/collect.rs index c3fadaed6645..12c26ef58730 100644 --- a/ballista/rust/executor/src/collect.rs +++ b/ballista/rust/executor/src/collect.rs @@ -27,6 +27,7 @@ use datafusion::arrow::{ datatypes::SchemaRef, error::Result as ArrowResult, record_batch::RecordBatch, }; use datafusion::error::DataFusionError; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; @@ -75,11 +76,12 @@ impl ExecutionPlan for CollectExec { async fn execute( &self, partition: usize, - ) -> Result>> { + runtime: Arc, + ) -> Result { assert_eq!(0, partition); let num_partitions = self.plan.output_partitioning().partition_count(); - let futures = (0..num_partitions).map(|i| self.plan.execute(i)); + let futures = (0..num_partitions).map(|i| self.plan.execute(i, runtime.clone())); let streams = futures::future::join_all(futures) .await .into_iter() diff --git a/ballista/rust/executor/src/executor.rs b/ballista/rust/executor/src/executor.rs index 398ebca2b8e6..a7cd8ebe5e92 100644 --- a/ballista/rust/executor/src/executor.rs +++ b/ballista/rust/executor/src/executor.rs @@ -23,6 +23,7 @@ use ballista_core::error::BallistaError; use ballista_core::execution_plans::ShuffleWriterExec; use ballista_core::serde::protobuf; use datafusion::error::DataFusionError; +use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::{ExecutionPlan, Partitioning}; @@ -71,7 +72,11 @@ impl Executor { )) }?; - let partitions = exec.execute_shuffle_write(part).await?; + let runtime_config = + RuntimeConfig::new().with_local_dirs(vec![self.work_dir.clone()]); + let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); + + let partitions = exec.execute_shuffle_write(part, runtime).await?; println!( "=== [{}/{}/{}] Physical plan with metrics ===\n{}\n", diff --git a/ballista/rust/scheduler/src/planner.rs b/ballista/rust/scheduler/src/planner.rs index 3291a62abe64..3d3884fd5021 100644 --- a/ballista/rust/scheduler/src/planner.rs +++ b/ballista/rust/scheduler/src/planner.rs @@ -254,7 +254,7 @@ mod test { use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; use datafusion::physical_plan::hash_join::HashJoinExec; - use datafusion::physical_plan::sort::SortExec; + use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::{ coalesce_partitions::CoalescePartitionsExec, projection::ProjectionExec, }; diff --git a/benchmarks/src/bin/nyctaxi.rs b/benchmarks/src/bin/nyctaxi.rs index 59fc69180368..ad2494c6aff2 100644 --- a/benchmarks/src/bin/nyctaxi.rs +++ b/benchmarks/src/bin/nyctaxi.rs @@ -116,13 +116,14 @@ async fn datafusion_sql_benchmarks( } async fn execute_sql(ctx: &mut ExecutionContext, sql: &str, debug: bool) -> Result<()> { + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); let plan = ctx.create_logical_plan(sql)?; let plan = ctx.optimize(&plan)?; if debug { println!("Optimized logical plan:\n{:?}", plan); } let physical_plan = ctx.create_physical_plan(&plan).await?; - let result = collect(physical_plan).await?; + let result = collect(physical_plan, runtime).await?; if debug { pretty::print_batches(&result)?; } diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index d9317fe38dd3..b676253eab2d 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -263,6 +263,7 @@ async fn benchmark_datafusion(opt: DataFusionBenchmarkOpt) -> Result Result) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { Self::Csv => write!(f, "csv"), Self::Tsv => write!(f, "tsv"), diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index 46e2cbec56e2..bc37c7a0de20 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -77,10 +77,10 @@ rand = "0.8" avro-rs = { version = "0.13", features = ["snappy"], optional = true } num-traits = { version = "0.2", optional = true } pyo3 = { version = "0.14", optional = true } +tempfile = "3" [dev-dependencies] criterion = "0.3" -tempfile = "3" doc-comment = "0.3" [[bench]] diff --git a/datafusion/benches/physical_plan.rs b/datafusion/benches/physical_plan.rs index 9222ae131b8f..e9eb53d69ef0 100644 --- a/datafusion/benches/physical_plan.rs +++ b/datafusion/benches/physical_plan.rs @@ -29,11 +29,12 @@ use arrow::{ }; use tokio::runtime::Runtime; +use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion::physical_plan::{ collect, expressions::{col, PhysicalSortExpr}, memory::MemoryExec, - sort_preserving_merge::SortPreservingMergeExec, }; // Initialise the operator using the provided record batches and the sort key @@ -58,7 +59,8 @@ fn sort_preserving_merge_operator(batches: Vec, sort: &[&str]) { let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec), 8192)); let rt = Runtime::new().unwrap(); - rt.block_on(collect(merge)).unwrap(); + let rt_env = Arc::new(RuntimeEnv::default()); + rt.block_on(collect(merge, rt_env)).unwrap(); } // Produces `n` record batches of row size `m`. Each record batch will have diff --git a/datafusion/benches/sort_limit_query_sql.rs b/datafusion/benches/sort_limit_query_sql.rs index f3151d2d7140..8f409285b5ab 100644 --- a/datafusion/benches/sort_limit_query_sql.rs +++ b/datafusion/benches/sort_limit_query_sql.rs @@ -79,13 +79,15 @@ fn create_context() -> Arc> { let partitions = 16; rt.block_on(async { - let mem_table = MemTable::load(Arc::new(csv), 16 * 1024, Some(partitions)) - .await - .unwrap(); - // create local execution context let mut ctx = ExecutionContext::new(); ctx.state.lock().unwrap().config.target_partitions = 1; + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + + let mem_table = + MemTable::load(Arc::new(csv), 16 * 1024, Some(partitions), runtime) + .await + .unwrap(); ctx.register_table("aggregate_test_100", Arc::new(mem_table)) .unwrap(); ctx_holder.lock().unwrap().push(Arc::new(Mutex::new(ctx))) diff --git a/datafusion/src/datasource/file_format/avro.rs b/datafusion/src/datasource/file_format/avro.rs index 515584b16c03..e1ae1743f94d 100644 --- a/datafusion/src/datasource/file_format/avro.rs +++ b/datafusion/src/datasource/file_format/avro.rs @@ -81,6 +81,7 @@ mod tests { }; use super::*; + use crate::execution::runtime_env::RuntimeEnv; use arrow::array::{ BinaryArray, BooleanArray, Float32Array, Float64Array, Int32Array, TimestampMicrosecondArray, @@ -90,8 +91,9 @@ mod tests { #[tokio::test] async fn read_small_batches() -> Result<()> { let projection = None; + let runtime = Arc::new(RuntimeEnv::default()); let exec = get_exec("alltypes_plain.avro", &projection, 2, None).await?; - let stream = exec.execute(0).await?; + let stream = exec.execute(0, runtime).await?; let tt_batches = stream .map(|batch| { diff --git a/datafusion/src/datasource/file_format/csv.rs b/datafusion/src/datasource/file_format/csv.rs index 337511316c51..99770a895d54 100644 --- a/datafusion/src/datasource/file_format/csv.rs +++ b/datafusion/src/datasource/file_format/csv.rs @@ -136,6 +136,7 @@ mod tests { use arrow::array::StringArray; use super::*; + use crate::execution::runtime_env::RuntimeEnv; use crate::{ datasource::{ file_format::PhysicalPlanConfig, @@ -149,10 +150,11 @@ mod tests { #[tokio::test] async fn read_small_batches() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); // skip column 9 that overflows the automaticly discovered column type of i64 (u64 would work) let projection = Some(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12]); let exec = get_exec("aggregate_test_100.csv", &projection, 2, None).await?; - let stream = exec.execute(0).await?; + let stream = exec.execute(0, runtime).await?; let tt_batches: i32 = stream .map(|batch| { @@ -174,9 +176,10 @@ mod tests { #[tokio::test] async fn read_limit() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![0, 1, 2, 3]); let exec = get_exec("aggregate_test_100.csv", &projection, 1024, Some(1)).await?; - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(4, batches[0].num_columns()); assert_eq!(1, batches[0].num_rows()); @@ -219,10 +222,11 @@ mod tests { #[tokio::test] async fn read_char_column() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![0]); let exec = get_exec("aggregate_test_100.csv", &projection, 1024, None).await?; - let batches = collect(exec).await.expect("Collect batches"); + let batches = collect(exec, runtime).await.expect("Collect batches"); assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); diff --git a/datafusion/src/datasource/file_format/json.rs b/datafusion/src/datasource/file_format/json.rs index b3fb1c4b464c..a8f11761f032 100644 --- a/datafusion/src/datasource/file_format/json.rs +++ b/datafusion/src/datasource/file_format/json.rs @@ -98,6 +98,7 @@ mod tests { use arrow::array::Int64Array; use super::*; + use crate::execution::runtime_env::RuntimeEnv; use crate::{ datasource::{ file_format::PhysicalPlanConfig, @@ -111,9 +112,10 @@ mod tests { #[tokio::test] async fn read_small_batches() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = None; let exec = get_exec(&projection, 2, None).await?; - let stream = exec.execute(0).await?; + let stream = exec.execute(0, runtime).await?; let tt_batches: i32 = stream .map(|batch| { @@ -135,9 +137,10 @@ mod tests { #[tokio::test] async fn read_limit() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = None; let exec = get_exec(&projection, 1024, Some(1)).await?; - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(4, batches[0].num_columns()); assert_eq!(1, batches[0].num_rows()); @@ -163,10 +166,11 @@ mod tests { #[tokio::test] async fn read_int_column() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![0]); let exec = get_exec(&projection, 1024, None).await?; - let batches = collect(exec).await.expect("Collect batches"); + let batches = collect(exec, runtime).await.expect("Collect batches"); assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); diff --git a/datafusion/src/datasource/file_format/parquet.rs b/datafusion/src/datasource/file_format/parquet.rs index 7976be7913c8..608795a873cb 100644 --- a/datafusion/src/datasource/file_format/parquet.rs +++ b/datafusion/src/datasource/file_format/parquet.rs @@ -341,6 +341,7 @@ mod tests { }; use super::*; + use crate::execution::runtime_env::RuntimeEnv; use arrow::array::{ BinaryArray, BooleanArray, Float32Array, Float64Array, Int32Array, TimestampNanosecondArray, @@ -349,9 +350,10 @@ mod tests { #[tokio::test] async fn read_small_batches() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = None; let exec = get_exec("alltypes_plain.parquet", &projection, 2, None).await?; - let stream = exec.execute(0).await?; + let stream = exec.execute(0, runtime).await?; let tt_batches = stream .map(|batch| { @@ -373,6 +375,7 @@ mod tests { #[tokio::test] async fn read_limit() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = None; let exec = get_exec("alltypes_plain.parquet", &projection, 1024, Some(1)).await?; @@ -380,7 +383,7 @@ mod tests { assert_eq!(exec.statistics().num_rows, Some(8)); assert_eq!(exec.statistics().total_byte_size, Some(671)); assert!(exec.statistics().is_exact); - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(11, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -390,6 +393,7 @@ mod tests { #[tokio::test] async fn read_alltypes_plain_parquet() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = None; let exec = get_exec("alltypes_plain.parquet", &projection, 1024, None).await?; @@ -415,7 +419,7 @@ mod tests { y ); - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(11, batches[0].num_columns()); @@ -426,10 +430,11 @@ mod tests { #[tokio::test] async fn read_bool_alltypes_plain_parquet() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![1]); let exec = get_exec("alltypes_plain.parquet", &projection, 1024, None).await?; - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -454,10 +459,11 @@ mod tests { #[tokio::test] async fn read_i32_alltypes_plain_parquet() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![0]); let exec = get_exec("alltypes_plain.parquet", &projection, 1024, None).await?; - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -479,10 +485,11 @@ mod tests { #[tokio::test] async fn read_i96_alltypes_plain_parquet() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![10]); let exec = get_exec("alltypes_plain.parquet", &projection, 1024, None).await?; - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -504,10 +511,11 @@ mod tests { #[tokio::test] async fn read_f32_alltypes_plain_parquet() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![6]); let exec = get_exec("alltypes_plain.parquet", &projection, 1024, None).await?; - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -532,10 +540,11 @@ mod tests { #[tokio::test] async fn read_f64_alltypes_plain_parquet() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![7]); let exec = get_exec("alltypes_plain.parquet", &projection, 1024, None).await?; - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -560,10 +569,11 @@ mod tests { #[tokio::test] async fn read_binary_alltypes_plain_parquet() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let projection = Some(vec![9]); let exec = get_exec("alltypes_plain.parquet", &projection, 1024, None).await?; - let batches = collect(exec).await?; + let batches = collect(exec, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); diff --git a/datafusion/src/datasource/memory.rs b/datafusion/src/datasource/memory.rs index b47e7e12e54e..ada323139ff8 100644 --- a/datafusion/src/datasource/memory.rs +++ b/datafusion/src/datasource/memory.rs @@ -29,6 +29,7 @@ use async_trait::async_trait; use crate::datasource::TableProvider; use crate::error::{DataFusionError, Result}; +use crate::execution::runtime_env::RuntimeEnv; use crate::logical_plan::Expr; use crate::physical_plan::common; use crate::physical_plan::memory::MemoryExec; @@ -65,6 +66,7 @@ impl MemTable { t: Arc, batch_size: usize, output_partitions: Option, + runtime: Arc, ) -> Result { let schema = t.schema(); let exec = t.scan(&None, batch_size, &[], None).await?; @@ -72,9 +74,10 @@ impl MemTable { let tasks = (0..partition_count) .map(|part_i| { + let runtime1 = runtime.clone(); let exec = exec.clone(); tokio::spawn(async move { - let stream = exec.execute(part_i).await?; + let stream = exec.execute(part_i, runtime1.clone()).await?; common::collect(stream).await }) }) @@ -101,7 +104,7 @@ impl MemTable { let mut output_partitions = vec![]; for i in 0..exec.output_partitioning().partition_count() { // execute this *output* partition and collect all batches - let mut stream = exec.execute(i).await?; + let mut stream = exec.execute(i, runtime.clone()).await?; let mut batches = vec![]; while let Some(result) = stream.next().await { batches.push(result?); @@ -150,6 +153,7 @@ mod tests { #[tokio::test] async fn test_with_projection() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), @@ -171,7 +175,7 @@ mod tests { // scan with projection let exec = provider.scan(&Some(vec![2, 1]), 1024, &[], None).await?; - let mut it = exec.execute(0).await?; + let mut it = exec.execute(0, runtime).await?; let batch2 = it.next().await.unwrap()?; assert_eq!(2, batch2.schema().fields().len()); assert_eq!("c", batch2.schema().field(0).name()); @@ -183,6 +187,7 @@ mod tests { #[tokio::test] async fn test_without_projection() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), @@ -201,7 +206,7 @@ mod tests { let provider = MemTable::try_new(schema, vec![vec![batch]])?; let exec = provider.scan(&None, 1024, &[], None).await?; - let mut it = exec.execute(0).await?; + let mut it = exec.execute(0, runtime).await?; let batch1 = it.next().await.unwrap()?; assert_eq!(3, batch1.schema().fields().len()); assert_eq!(3, batch1.num_columns()); @@ -308,6 +313,7 @@ mod tests { #[tokio::test] async fn test_merged_schema() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let mut metadata = HashMap::new(); metadata.insert("foo".to_string(), "bar".to_string()); @@ -352,7 +358,7 @@ mod tests { MemTable::try_new(Arc::new(merged_schema), vec![vec![batch1, batch2]])?; let exec = provider.scan(&None, 1024, &[], None).await?; - let mut it = exec.execute(0).await?; + let mut it = exec.execute(0, runtime).await?; let batch1 = it.next().await.unwrap()?; assert_eq!(3, batch1.schema().fields().len()); assert_eq!(3, batch1.num_columns()); diff --git a/datafusion/src/datasource/object_store/mod.rs b/datafusion/src/datasource/object_store/mod.rs index aece82ac2cf2..77ca1ef6bae7 100644 --- a/datafusion/src/datasource/object_store/mod.rs +++ b/datafusion/src/datasource/object_store/mod.rs @@ -171,7 +171,7 @@ pub struct ObjectStoreRegistry { } impl fmt::Debug for ObjectStoreRegistry { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("ObjectStoreRegistry") .field( "schemes", diff --git a/datafusion/src/error.rs b/datafusion/src/error.rs index 6b6bb1381111..5e94e141688c 100644 --- a/datafusion/src/error.rs +++ b/datafusion/src/error.rs @@ -61,6 +61,9 @@ pub enum DataFusionError { /// Error returned during execution of the query. /// Examples include files not found, errors in parsing certain types. Execution(String), + /// This error is thrown when a consumer cannot acquire memory from the Memory Manager + /// we can just cancel the execution of the partition. + ResourcesExhausted(String), } impl DataFusionError { @@ -102,7 +105,7 @@ impl From for DataFusionError { } impl Display for DataFusionError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { match *self { DataFusionError::ArrowError(ref desc) => write!(f, "Arrow error: {}", desc), DataFusionError::ParquetError(ref desc) => { @@ -129,6 +132,9 @@ impl Display for DataFusionError { DataFusionError::Execution(ref desc) => { write!(f, "Execution error: {}", desc) } + DataFusionError::ResourcesExhausted(ref desc) => { + write!(f, "Resources exhausted: {}", desc) + } } } } diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 944284b96f97..89ccd7b2b938 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -76,6 +76,7 @@ use crate::physical_optimizer::coalesce_batches::CoalesceBatches; use crate::physical_optimizer::merge_exec::AddCoalescePartitionsExec; use crate::physical_optimizer::repartition::Repartition; +use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use crate::logical_plan::plan::Explain; use crate::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy; use crate::physical_plan::planner::DefaultPhysicalPlanner; @@ -178,6 +179,9 @@ impl ExecutionContext { .register_catalog(config.default_catalog.clone(), default_catalog); } + let runtime_env = + Arc::new(RuntimeEnv::new(config.runtime_config.clone()).unwrap()); + Self { state: Arc::new(Mutex::new(ExecutionContextState { catalog_list, @@ -187,6 +191,7 @@ impl ExecutionContext { config, execution_props: ExecutionProps::new(), object_store_registry: Arc::new(ObjectStoreRegistry::new()), + runtime_env, })), } } @@ -713,6 +718,7 @@ impl ExecutionContext { let path = path.as_ref(); // create directory to contain the CSV files (one per partition) let fs_path = Path::new(path); + let runtime = self.state.lock().unwrap().runtime_env.clone(); match fs::create_dir(fs_path) { Ok(()) => { let mut tasks = vec![]; @@ -722,7 +728,7 @@ impl ExecutionContext { let path = fs_path.join(&filename); let file = fs::File::create(path)?; let mut writer = csv::Writer::new(file); - let stream = plan.execute(i).await?; + let stream = plan.execute(i, runtime.clone()).await?; let handle: JoinHandle> = task::spawn(async move { stream .map(|batch| writer.write(&batch?)) @@ -752,6 +758,7 @@ impl ExecutionContext { let path = path.as_ref(); // create directory to contain the Parquet files (one per partition) let fs_path = Path::new(path); + let runtime = self.state.lock().unwrap().runtime_env.clone(); match fs::create_dir(fs_path) { Ok(()) => { let mut tasks = vec![]; @@ -765,7 +772,7 @@ impl ExecutionContext { plan.schema(), writer_properties.clone(), )?; - let stream = plan.execute(i).await?; + let stream = plan.execute(i, runtime.clone()).await?; let handle: JoinHandle> = task::spawn(async move { stream .map(|batch| writer.write(&batch?)) @@ -893,6 +900,8 @@ pub struct ExecutionConfig { pub repartition_windows: bool, /// Should Datafusion parquet reader using the predicate to prune data parquet_pruning: bool, + /// Runtime configurations such as memory threshold and local disk for spill + pub runtime_config: RuntimeConfig, } impl Default for ExecutionConfig { @@ -927,6 +936,7 @@ impl Default for ExecutionConfig { repartition_aggregations: true, repartition_windows: true, parquet_pruning: true, + runtime_config: RuntimeConfig::default(), } } } @@ -1044,6 +1054,12 @@ impl ExecutionConfig { self.parquet_pruning = enabled; self } + + /// Customize runtime config + pub fn with_runtime_config(mut self, config: RuntimeConfig) -> Self { + self.runtime_config = config; + self + } } /// Holds per-execution properties and data (such as starting timestamps, etc). @@ -1093,6 +1109,8 @@ pub struct ExecutionContextState { pub execution_props: ExecutionProps, /// Object Store that are registered with the context pub object_store_registry: Arc, + /// Runtime environment + pub runtime_env: Arc, } impl Default for ExecutionContextState { @@ -1112,6 +1130,7 @@ impl ExecutionContextState { config: ExecutionConfig::new(), execution_props: ExecutionProps::new(), object_store_registry: Arc::new(ObjectStoreRegistry::new()), + runtime_env: Arc::new(RuntimeEnv::default()), } } @@ -1378,7 +1397,8 @@ mod tests { let physical_plan = ctx.create_physical_plan(&logical_plan).await?; - let results = collect_partitioned(physical_plan).await?; + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let results = collect_partitioned(physical_plan, runtime).await?; // note that the order of partitions is not deterministic let mut num_rows = 0; @@ -1426,6 +1446,7 @@ mod tests { let tmp_dir = TempDir::new()?; let partition_count = 4; let ctx = create_ctx(&tmp_dir, partition_count).await?; + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); let table = ctx.table("test")?; let logical_plan = LogicalPlanBuilder::from(table.to_logical_plan()) @@ -1457,7 +1478,7 @@ mod tests { assert_eq!(1, physical_plan.schema().fields().len()); assert_eq!("c2", physical_plan.schema().field(0).name().as_str()); - let batches = collect(physical_plan).await?; + let batches = collect(physical_plan, runtime).await?; assert_eq!(40, batches.iter().map(|x| x.num_rows()).sum::()); Ok(()) @@ -1533,7 +1554,8 @@ mod tests { assert_eq!(1, physical_plan.schema().fields().len()); assert_eq!("b", physical_plan.schema().field(0).name().as_str()); - let batches = collect(physical_plan).await?; + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let batches = collect(physical_plan, runtime).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(4, batches[0].num_rows()); @@ -3307,7 +3329,8 @@ mod tests { let plan = ctx.optimize(&plan)?; let plan = ctx.create_physical_plan(&plan).await?; - let result = collect(plan).await?; + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let result = collect(plan, runtime).await?; let expected = vec![ "+-----+-----+-----------------+", diff --git a/datafusion/src/execution/dataframe_impl.rs b/datafusion/src/execution/dataframe_impl.rs index 2887e29ada7e..f2d0385a3fe0 100644 --- a/datafusion/src/execution/dataframe_impl.rs +++ b/datafusion/src/execution/dataframe_impl.rs @@ -162,7 +162,8 @@ impl DataFrame for DataFrameImpl { /// execute it, collecting all resulting batches into memory async fn collect(&self) -> Result> { let plan = self.create_physical_plan().await?; - Ok(collect(plan).await?) + let runtime = self.ctx_state.lock().unwrap().runtime_env.clone(); + Ok(collect(plan, runtime).await?) } /// Print results. @@ -181,7 +182,8 @@ impl DataFrame for DataFrameImpl { /// execute it, returning a stream over a single partition async fn execute_stream(&self) -> Result { let plan = self.create_physical_plan().await?; - execute_stream(plan).await + let runtime = self.ctx_state.lock().unwrap().runtime_env.clone(); + execute_stream(plan, runtime).await } /// Convert the logical plan represented by this DataFrame into a physical plan and @@ -189,14 +191,16 @@ impl DataFrame for DataFrameImpl { /// partitioning async fn collect_partitioned(&self) -> Result>> { let plan = self.create_physical_plan().await?; - Ok(collect_partitioned(plan).await?) + let runtime = self.ctx_state.lock().unwrap().runtime_env.clone(); + Ok(collect_partitioned(plan, runtime).await?) } /// Convert the logical plan represented by this DataFrame into a physical plan and /// execute it, returning a stream for each partition async fn execute_stream_partitioned(&self) -> Result> { let plan = self.create_physical_plan().await?; - Ok(execute_stream_partitioned(plan).await?) + let runtime = self.ctx_state.lock().unwrap().runtime_env.clone(); + Ok(execute_stream_partitioned(plan, runtime).await?) } /// Returns the schema from the logical plan diff --git a/datafusion/src/execution/disk_manager.rs b/datafusion/src/execution/disk_manager.rs new file mode 100644 index 000000000000..c4a6b1d6d899 --- /dev/null +++ b/datafusion/src/execution/disk_manager.rs @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Manages files generated during query execution, files are +//! hashed among the directories listed in RuntimeConfig::local_dirs. + +use crate::error::{DataFusionError, Result}; +use log::info; +use rand::distributions::Alphanumeric; +use rand::{thread_rng, Rng}; +use std::collections::hash_map::DefaultHasher; +use std::fs::File; +use std::hash::{Hash, Hasher}; +use std::path::{Path, PathBuf}; +use tempfile::{Builder, TempDir}; + +/// Manages files generated during query execution, e.g. spill files generated +/// while processing dataset larger than available memory. +pub struct DiskManager { + local_dirs: Vec, +} + +impl DiskManager { + /// Create local dirs inside user provided dirs through conf + pub fn new(conf_dirs: &[String]) -> Result { + let local_dirs = create_local_dirs(conf_dirs)?; + info!( + "Created local dirs {:?} as DataFusion working directory", + local_dirs + ); + Ok(Self { local_dirs }) + } + + /// Create a file in conf dirs in randomized manner and return the file path + pub fn create_tmp_file(&self) -> Result { + create_tmp_file(&self.local_dirs) + } +} + +/// Setup local dirs by creating one new dir in each of the given dirs +fn create_local_dirs(local_dir: &[String]) -> Result> { + local_dir + .iter() + .map(|root| create_dir(root, "datafusion-")) + .collect() +} + +fn create_dir(root: &str, prefix: &str) -> Result { + Builder::new() + .prefix(prefix) + .tempdir_in(root) + .map_err(DataFusionError::IoError) +} + +fn get_file(file_name: &str, local_dirs: &[TempDir]) -> String { + let mut hasher = DefaultHasher::new(); + file_name.hash(&mut hasher); + let hash = hasher.finish(); + let dir = &local_dirs[hash.rem_euclid(local_dirs.len() as u64) as usize]; + let mut path = PathBuf::new(); + path.push(dir); + path.push(file_name); + path.to_str().unwrap().to_string() +} + +fn create_tmp_file(local_dirs: &[TempDir]) -> Result { + let name = rand_name(); + let mut path = get_file(&*name, local_dirs); + while Path::new(path.as_str()).exists() { + path = get_file(&rand_name(), local_dirs); + } + File::create(&path)?; + Ok(path) +} + +/// Return a random string suitable for use as a database name +fn rand_name() -> String { + thread_rng() + .sample_iter(&Alphanumeric) + .take(10) + .map(char::from) + .collect() +} + +#[cfg(test)] +mod tests { + use crate::error::Result; + use crate::execution::disk_manager::{get_file, DiskManager}; + use tempfile::TempDir; + + #[test] + fn file_in_right_dir() -> Result<()> { + let local_dir1 = TempDir::new()?; + let local_dir2 = TempDir::new()?; + let local_dir3 = TempDir::new()?; + let local_dirs = vec![ + local_dir1.path().to_str().unwrap().to_string(), + local_dir2.path().to_str().unwrap().to_string(), + local_dir3.path().to_str().unwrap().to_string(), + ]; + + let dm = DiskManager::new(&local_dirs)?; + let actual = dm.create_tmp_file()?; + let name = actual.rsplit_once(std::path::MAIN_SEPARATOR).unwrap().1; + + let expected = get_file(name, &dm.local_dirs); + // file should be located in dir by it's name hash + assert_eq!(actual, expected); + Ok(()) + } +} diff --git a/datafusion/src/execution/memory_manager.rs b/datafusion/src/execution/memory_manager.rs new file mode 100644 index 000000000000..caa597bea603 --- /dev/null +++ b/datafusion/src/execution/memory_manager.rs @@ -0,0 +1,488 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Manages all available memory during query execution + +use crate::error::Result; +use async_trait::async_trait; +use hashbrown::HashMap; +use log::info; +use std::fmt; +use std::fmt::{Debug, Display, Formatter}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Condvar, Mutex, Weak}; + +static CONSUMER_ID: AtomicUsize = AtomicUsize::new(0); + +fn next_id() -> usize { + CONSUMER_ID.fetch_add(1, Ordering::SeqCst) +} + +/// Type of the memory consumer +pub enum ConsumerType { + /// consumers that can grow its memory usage by requesting more from the memory manager or + /// shrinks its memory usage when we can no more assign available memory to it. + /// Examples are spillable sorter, spillable hashmap, etc. + Requesting, + /// consumers that are not spillable, counting in for only tracking purpose. + Tracking, +} + +#[derive(Clone, Debug, Hash, Eq, PartialEq)] +/// Id that uniquely identifies a Memory Consumer +pub struct MemoryConsumerId { + /// partition the consumer belongs to + pub partition_id: usize, + /// unique id + pub id: usize, +} + +impl MemoryConsumerId { + /// Auto incremented new Id + pub fn new(partition_id: usize) -> Self { + let id = next_id(); + Self { partition_id, id } + } +} + +impl Display for MemoryConsumerId { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}:{}", self.partition_id, self.id) + } +} + +#[async_trait] +/// A memory consumer that either takes up memory (of type `ConsumerType::Tracking`) +/// or grows/shrinks memory usage based on available memory (of type `ConsumerType::Requesting`). +pub trait MemoryConsumer: Send + Sync { + /// Display name of the consumer + fn name(&self) -> String; + + /// Unique id of the consumer + fn id(&self) -> &MemoryConsumerId; + + /// Ptr to MemoryManager + fn memory_manager(&self) -> Arc; + + /// Partition that the consumer belongs to + fn partition_id(&self) -> usize { + self.id().partition_id + } + + /// Type of the consumer + fn type_(&self) -> &ConsumerType; + + /// Grow memory by `required` to buffer more data in memory, + /// this may trigger spill before grow when the memory threshold is + /// reached for this consumer. + async fn try_grow(&self, required: usize) -> Result<()> { + let current = self.mem_used(); + info!( + "trying to acquire {} whiling holding {} from consumer {}", + human_readable_size(required), + human_readable_size(current), + self.id(), + ); + + let can_grow_directly = self + .memory_manager() + .can_grow_directly(required, current) + .await; + if !can_grow_directly { + info!( + "Failed to grow memory of {} directly from consumer {}, spilling first ...", + human_readable_size(required), + self.id() + ); + let freed = self.spill().await?; + self.memory_manager() + .record_free_then_acquire(freed, required); + } + Ok(()) + } + + /// Spill in-memory buffers to disk, free memory, return the previous used + async fn spill(&self) -> Result; + + /// Current memory used by this consumer + fn mem_used(&self) -> usize; +} + +impl Debug for dyn MemoryConsumer { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!( + f, + "{}[{}]: {}", + self.name(), + self.id(), + human_readable_size(self.mem_used()) + ) + } +} + +impl Display for dyn MemoryConsumer { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}[{}]", self.name(), self.id(),) + } +} + +/* +The memory management architecture is the following: + +1. User designates max execution memory by setting RuntimeConfig.max_memory and RuntimeConfig.memory_fraction (float64 between 0..1). + The actual max memory DataFusion could use `pool_size = max_memory * memory_fraction`. +2. The entities that take up memory during its execution are called 'Memory Consumers'. Operators or others are encouraged to + register themselves to the memory manager and report its usage through `mem_used()`. +3. There are two kinds of consumers: + - 'Requesting' consumers that would acquire memory during its execution and release memory through `spill` if no more memory is available. + - 'Tracking' consumers that exist for reporting purposes to provide a more accurate memory usage estimation for memory consumers. +4. Requesting and tracking consumers share the pool. Each controlling consumer could acquire a maximum of + (pool_size - all_tracking_used) / active_num_controlling_consumers. + + Memory Space for the DataFusion Lib / Process of `pool_size` + ┌──────────────────────────────────────────────z─────────────────────────────┐ + │ z │ + │ z │ + │ Requesting z Tracking │ + │ Memory Consumers z Memory Consumers │ + │ z │ + │ z │ + └──────────────────────────────────────────────z─────────────────────────────┘ +*/ + +/// Manage memory usage during physical plan execution +pub struct MemoryManager { + requesters: Arc>>>, + trackers: Arc>>>, + pool_size: usize, + requesters_total: Arc>, + cv: Condvar, +} + +impl MemoryManager { + /// Create new memory manager based on max available pool_size + #[allow(clippy::mutex_atomic)] + pub fn new(pool_size: usize) -> Self { + info!( + "Creating memory manager with initial size {}", + human_readable_size(pool_size) + ); + Self { + requesters: Arc::new(Mutex::new(HashMap::new())), + trackers: Arc::new(Mutex::new(HashMap::new())), + pool_size, + requesters_total: Arc::new(Mutex::new(0)), + cv: Condvar::new(), + } + } + + fn get_tracker_total(&self) -> usize { + let trackers = self.trackers.lock().unwrap(); + if trackers.len() > 0 { + trackers.values().fold(0usize, |acc, y| match y.upgrade() { + None => acc, + Some(t) => acc + t.mem_used(), + }) + } else { + 0 + } + } + + /// Register a new memory consumer for memory usage tracking + pub(crate) fn register_consumer(&self, consumer: &Arc) { + let id = consumer.id().clone(); + match consumer.type_() { + ConsumerType::Requesting => { + let mut requesters = self.requesters.lock().unwrap(); + requesters.insert(id, Arc::downgrade(consumer)); + } + ConsumerType::Tracking => { + let mut trackers = self.trackers.lock().unwrap(); + trackers.insert(id, Arc::downgrade(consumer)); + } + } + } + + fn max_mem_for_requesters(&self) -> usize { + let trk_total = self.get_tracker_total(); + self.pool_size - trk_total + } + + /// Grow memory attempt from a consumer, return if we could grant that much to it + async fn can_grow_directly(&self, required: usize, current: usize) -> bool { + let num_rqt = self.requesters.lock().unwrap().len(); + let mut rqt_current_used = self.requesters_total.lock().unwrap(); + let mut rqt_max = self.max_mem_for_requesters(); + + let granted; + loop { + let remaining = rqt_max - *rqt_current_used; + let max_per_rqt = rqt_max / num_rqt; + let min_per_rqt = max_per_rqt / 2; + + if required + current >= max_per_rqt { + granted = false; + break; + } + + if remaining >= required { + granted = true; + *rqt_current_used += required; + break; + } else if current < min_per_rqt { + // if we cannot acquire at lease 1/2n memory, just wait for others + // to spill instead spill self frequently with limited total mem + rqt_current_used = self.cv.wait(rqt_current_used).unwrap(); + } else { + granted = false; + break; + } + + rqt_max = self.max_mem_for_requesters(); + } + + granted + } + + fn record_free_then_acquire(&self, freed: usize, acquired: usize) { + let mut requesters_total = self.requesters_total.lock().unwrap(); + *requesters_total -= freed; + *requesters_total += acquired; + self.cv.notify_all() + } + + /// Drop a memory consumer from memory usage tracking + pub(crate) fn drop_consumer(&self, id: &MemoryConsumerId) { + // find in requesters first + { + let mut requesters = self.requesters.lock().unwrap(); + if requesters.remove(id).is_some() { + return; + } + } + let mut trackers = self.trackers.lock().unwrap(); + trackers.remove(id); + } +} + +impl Display for MemoryManager { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + let requesters = + self.requesters + .lock() + .unwrap() + .values() + .fold(vec![], |mut acc, consumer| match consumer.upgrade() { + None => acc, + Some(c) => { + acc.push(format!("{}", c)); + acc + } + }); + let tracker_mem = self.get_tracker_total(); + write!(f, + "MemoryManager usage statistics: total {}, tracker used {}, total {} requesters detail: \n {},", + human_readable_size(self.pool_size), + human_readable_size(tracker_mem), + &requesters.len(), + requesters.join("\n")) + } +} + +const TB: u64 = 1 << 40; +const GB: u64 = 1 << 30; +const MB: u64 = 1 << 20; +const KB: u64 = 1 << 10; + +fn human_readable_size(size: usize) -> String { + let size = size as u64; + let (value, unit) = { + if size >= 2 * TB { + (size as f64 / TB as f64, "TB") + } else if size >= 2 * GB { + (size as f64 / GB as f64, "GB") + } else if size >= 2 * MB { + (size as f64 / MB as f64, "MB") + } else if size >= 2 * KB { + (size as f64 / KB as f64, "KB") + } else { + (size as f64, "B") + } + }; + format!("{:.1} {}", value, unit) +} + +#[cfg(test)] +mod tests { + use crate::error::Result; + use crate::execution::memory_manager::{ + ConsumerType, MemoryConsumer, MemoryConsumerId, MemoryManager, + }; + use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use async_trait::async_trait; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + + struct DummyRequester { + id: MemoryConsumerId, + runtime: Arc, + spills: AtomicUsize, + mem_used: AtomicUsize, + } + + impl DummyRequester { + fn new(partition: usize, runtime: Arc) -> Self { + Self { + id: MemoryConsumerId::new(partition), + runtime, + spills: AtomicUsize::new(0), + mem_used: AtomicUsize::new(0), + } + } + + async fn do_with_mem(&self, grow: usize) -> Result<()> { + self.try_grow(grow).await?; + self.mem_used.fetch_add(grow, Ordering::SeqCst); + Ok(()) + } + + fn get_spills(&self) -> usize { + self.spills.load(Ordering::SeqCst) + } + } + + #[async_trait] + impl MemoryConsumer for DummyRequester { + fn name(&self) -> String { + "dummy".to_owned() + } + + fn id(&self) -> &MemoryConsumerId { + &self.id + } + + fn memory_manager(&self) -> Arc { + self.runtime.memory_manager.clone() + } + + fn type_(&self) -> &ConsumerType { + &ConsumerType::Requesting + } + + async fn spill(&self) -> Result { + self.spills.fetch_add(1, Ordering::SeqCst); + let used = self.mem_used.swap(0, Ordering::SeqCst); + Ok(used) + } + + fn mem_used(&self) -> usize { + self.mem_used.load(Ordering::SeqCst) + } + } + + struct DummyTracker { + id: MemoryConsumerId, + runtime: Arc, + mem_used: usize, + } + + impl DummyTracker { + fn new(partition: usize, runtime: Arc, mem_used: usize) -> Self { + Self { + id: MemoryConsumerId::new(partition), + runtime, + mem_used, + } + } + } + + #[async_trait] + impl MemoryConsumer for DummyTracker { + fn name(&self) -> String { + "dummy".to_owned() + } + + fn id(&self) -> &MemoryConsumerId { + &self.id + } + + fn memory_manager(&self) -> Arc { + self.runtime.memory_manager.clone() + } + + fn type_(&self) -> &ConsumerType { + &ConsumerType::Tracking + } + + async fn spill(&self) -> Result { + Ok(0) + } + + fn mem_used(&self) -> usize { + self.mem_used + } + } + + #[tokio::test] + async fn basic_functionalities() -> Result<()> { + let config = RuntimeConfig::new() + .with_memory_fraction(1.0) + .with_max_execution_memory(100); + let runtime = Arc::new(RuntimeEnv::new(config)?); + + let tracker1 = Arc::new(DummyTracker::new(0, runtime.clone(), 5)); + runtime.register_consumer(&(tracker1.clone() as Arc)); + assert_eq!(runtime.memory_manager.get_tracker_total(), 5); + + let tracker2 = Arc::new(DummyTracker::new(0, runtime.clone(), 10)); + runtime.register_consumer(&(tracker2.clone() as Arc)); + assert_eq!(runtime.memory_manager.get_tracker_total(), 15); + + let tracker3 = Arc::new(DummyTracker::new(0, runtime.clone(), 15)); + runtime.register_consumer(&(tracker3.clone() as Arc)); + assert_eq!(runtime.memory_manager.get_tracker_total(), 30); + + runtime.drop_consumer(tracker2.id()); + assert_eq!(runtime.memory_manager.get_tracker_total(), 20); + + let requester1 = Arc::new(DummyRequester::new(0, runtime.clone())); + runtime.register_consumer(&(requester1.clone() as Arc)); + + // first requester entered, should be able to use any of the remaining 80 + requester1.do_with_mem(40).await?; + requester1.do_with_mem(10).await?; + assert_eq!(requester1.get_spills(), 0); + assert_eq!(requester1.mem_used(), 50); + assert_eq!(*runtime.memory_manager.requesters_total.lock().unwrap(), 50); + + let requester2 = Arc::new(DummyRequester::new(0, runtime.clone())); + runtime.register_consumer(&(requester2.clone() as Arc)); + + requester2.do_with_mem(20).await?; + requester2.do_with_mem(30).await?; + assert_eq!(requester2.get_spills(), 1); + assert_eq!(requester2.mem_used(), 30); + + requester1.do_with_mem(10).await?; + assert_eq!(requester1.get_spills(), 1); + assert_eq!(requester1.mem_used(), 10); + + assert_eq!(*runtime.memory_manager.requesters_total.lock().unwrap(), 40); + + Ok(()) + } +} diff --git a/datafusion/src/execution/mod.rs b/datafusion/src/execution/mod.rs index e353a3160b8d..ebc7c011970b 100644 --- a/datafusion/src/execution/mod.rs +++ b/datafusion/src/execution/mod.rs @@ -19,4 +19,7 @@ pub mod context; pub mod dataframe_impl; +pub(crate) mod disk_manager; +pub(crate) mod memory_manager; pub mod options; +pub mod runtime_env; diff --git a/datafusion/src/execution/runtime_env.rs b/datafusion/src/execution/runtime_env.rs new file mode 100644 index 000000000000..1e1aecd33c1d --- /dev/null +++ b/datafusion/src/execution/runtime_env.rs @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Execution runtime environment that tracks memory, disk and various configurations +//! that are used during physical plan execution. + +use crate::error::Result; +use crate::execution::disk_manager::DiskManager; +use crate::execution::memory_manager::{MemoryConsumer, MemoryConsumerId, MemoryManager}; +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; + +#[derive(Clone)] +/// Execution runtime environment +pub struct RuntimeEnv { + /// Runtime configuration + pub config: RuntimeConfig, + /// Runtime memory management + pub memory_manager: Arc, + /// Manage temporary files during query execution + pub disk_manager: Arc, +} + +impl Debug for RuntimeEnv { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "RuntimeEnv") + } +} + +impl RuntimeEnv { + /// Create env based on configuration + pub fn new(config: RuntimeConfig) -> Result { + let memory_manager = Arc::new(MemoryManager::new( + (config.max_memory as f64 * config.memory_fraction) as usize, + )); + let disk_manager = Arc::new(DiskManager::new(&config.local_dirs)?); + Ok(Self { + config, + memory_manager, + disk_manager, + }) + } + + /// Get execution batch size based on config + pub fn batch_size(&self) -> usize { + self.config.batch_size + } + + /// Register the consumer to get it tracked + pub fn register_consumer(&self, memory_consumer: &Arc) { + self.memory_manager.register_consumer(memory_consumer); + } + + /// Drop the consumer from get tracked + pub fn drop_consumer(&self, id: &MemoryConsumerId) { + self.memory_manager.drop_consumer(id) + } +} + +impl Default for RuntimeEnv { + fn default() -> Self { + RuntimeEnv::new(RuntimeConfig::new()).unwrap() + } +} + +#[derive(Clone)] +/// Execution runtime configuration +pub struct RuntimeConfig { + /// Default batch size while creating new batches, it's especially useful + /// for buffer-in-memory batches since creating tiny batches would results + /// in too much metadata memory consumption. + pub batch_size: usize, + /// Max execution memory allowed for DataFusion. + /// Defaults to `usize::MAX` + pub max_memory: usize, + /// The fraction of total memory used for execution. + /// The purpose of this config is to set aside memory for untracked data structures, + /// and imprecise size estimation during memory acquisition. + /// Defaults to 0.7 + pub memory_fraction: f64, + /// Local dirs to store temporary files during execution. + pub local_dirs: Vec, +} + +impl RuntimeConfig { + /// New with default values + pub fn new() -> Self { + Default::default() + } + + /// Customize batch size + pub fn with_batch_size(mut self, n: usize) -> Self { + // batch size must be greater than zero + assert!(n > 0); + self.batch_size = n; + self + } + + /// Customize exec size + pub fn with_max_execution_memory(mut self, max_memory: usize) -> Self { + assert!(max_memory > 0); + self.max_memory = max_memory; + self + } + + /// Customize exec memory fraction + pub fn with_memory_fraction(mut self, fraction: f64) -> Self { + assert!(fraction > 0f64 && fraction <= 1f64); + self.memory_fraction = fraction; + self + } + + /// Customize exec size + pub fn with_local_dirs(mut self, local_dirs: Vec) -> Self { + assert!(!local_dirs.is_empty()); + self.local_dirs = local_dirs; + self + } +} + +impl Default for RuntimeConfig { + fn default() -> Self { + let tmp_dir = tempfile::tempdir().unwrap(); + let path = tmp_dir.path().to_str().unwrap().to_string(); + std::mem::forget(tmp_dir); + + Self { + batch_size: 8192, + // Effectively "no limit" + max_memory: usize::MAX, + memory_fraction: 0.7, + local_dirs: vec![path], + } + } +} diff --git a/datafusion/src/logical_plan/dfschema.rs b/datafusion/src/logical_plan/dfschema.rs index 9db720eac587..7b6471f64dd7 100644 --- a/datafusion/src/logical_plan/dfschema.rs +++ b/datafusion/src/logical_plan/dfschema.rs @@ -389,7 +389,7 @@ impl ToDFSchema for Vec { } impl Display for DFSchema { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { write!( f, "{}", diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index dadc16853074..00877dda48dc 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -169,7 +169,7 @@ impl FromStr for Column { } impl fmt::Display for Column { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match &self.relation { Some(r) => write!(f, "#{}.{}", r, self.name), None => write!(f, "#{}", self.name), diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs index 952572f4dea3..b40dfc0103fc 100644 --- a/datafusion/src/logical_plan/plan.rs +++ b/datafusion/src/logical_plan/plan.rs @@ -1027,7 +1027,7 @@ pub enum PlanType { } impl fmt::Display for PlanType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { PlanType::InitialLogicalPlan => write!(f, "initial_logical_plan"), PlanType::OptimizedLogicalPlan { optimizer_name } => { diff --git a/datafusion/src/physical_optimizer/aggregate_statistics.rs b/datafusion/src/physical_optimizer/aggregate_statistics.rs index 515f7322bcae..4ae6ce3638cc 100644 --- a/datafusion/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/src/physical_optimizer/aggregate_statistics.rs @@ -259,6 +259,7 @@ mod tests { use arrow::record_batch::RecordBatch; use crate::error::Result; + use crate::execution::runtime_env::RuntimeEnv; use crate::logical_plan::Operator; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::common; @@ -295,6 +296,7 @@ mod tests { nulls: bool, ) -> Result<()> { let conf = ExecutionConfig::new(); + let runtime = Arc::new(RuntimeEnv::default()); let optimized = AggregateStatistics::new().optimize(Arc::new(plan), &conf)?; let (col, count) = match nulls { @@ -304,7 +306,7 @@ mod tests { // A ProjectionExec is a sign that the count optimization was applied assert!(optimized.as_any().is::()); - let result = common::collect(optimized.execute(0).await?).await?; + let result = common::collect(optimized.execute(0, runtime).await?).await?; assert_eq!(result[0].schema(), Arc::new(Schema::new(vec![col]))); assert_eq!( result[0] diff --git a/datafusion/src/physical_plan/analyze.rs b/datafusion/src/physical_plan/analyze.rs index c9e316effcfb..0a810b915945 100644 --- a/datafusion/src/physical_plan/analyze.rs +++ b/datafusion/src/physical_plan/analyze.rs @@ -31,6 +31,7 @@ use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatc use futures::StreamExt; use super::{stream::RecordBatchReceiverStream, Distribution, SendableRecordBatchStream}; +use crate::execution::runtime_env::RuntimeEnv; use async_trait::async_trait; /// `EXPLAIN ANALYZE` execution plan operator. This operator runs its input, @@ -99,7 +100,11 @@ impl ExecutionPlan for AnalyzeExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { if 0 != partition { return Err(DataFusionError::Internal(format!( "AnalyzeExec invalid partition. Expected 0, got {}", @@ -119,7 +124,7 @@ impl ExecutionPlan for AnalyzeExec { let (tx, rx) = tokio::sync::mpsc::channel(input_partitions); let captured_input = self.input.clone(); - let mut input_stream = captured_input.execute(0).await?; + let mut input_stream = captured_input.execute(0, runtime).await?; let captured_schema = self.schema.clone(); let verbose = self.verbose; @@ -236,6 +241,7 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -243,7 +249,7 @@ mod tests { let refs = blocking_exec.refs(); let analyze_exec = Arc::new(AnalyzeExec::new(true, blocking_exec, schema)); - let fut = collect(analyze_exec); + let fut = collect(analyze_exec, runtime); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/src/physical_plan/coalesce_batches.rs b/datafusion/src/physical_plan/coalesce_batches.rs index 7397493c3a74..67ef2846e546 100644 --- a/datafusion/src/physical_plan/coalesce_batches.rs +++ b/datafusion/src/physical_plan/coalesce_batches.rs @@ -29,6 +29,7 @@ use crate::physical_plan::{ SendableRecordBatchStream, }; +use crate::execution::runtime_env::RuntimeEnv; use arrow::compute::kernels::concat::concat; use arrow::datatypes::SchemaRef; use arrow::error::Result as ArrowResult; @@ -111,9 +112,13 @@ impl ExecutionPlan for CoalesceBatchesExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { Ok(Box::pin(CoalesceBatchesStream { - input: self.input.execute(partition).await?, + input: self.input.execute(partition, runtime).await?, schema: self.input.schema(), target_batch_size: self.target_batch_size, buffer: Vec::new(), @@ -351,9 +356,10 @@ mod tests { // execute and collect results let output_partition_count = exec.output_partitioning().partition_count(); let mut output_partitions = Vec::with_capacity(output_partition_count); + let runtime = Arc::new(RuntimeEnv::default()); for i in 0..output_partition_count { // execute this *output* partition and collect all batches - let mut stream = exec.execute(i).await?; + let mut stream = exec.execute(i, runtime.clone()).await?; let mut batches = vec![]; while let Some(result) = stream.next().await { batches.push(result?); diff --git a/datafusion/src/physical_plan/coalesce_partitions.rs b/datafusion/src/physical_plan/coalesce_partitions.rs index 089c6b4617aa..3fcacbb2f60a 100644 --- a/datafusion/src/physical_plan/coalesce_partitions.rs +++ b/datafusion/src/physical_plan/coalesce_partitions.rs @@ -37,6 +37,7 @@ use crate::error::{DataFusionError, Result}; use crate::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning}; use super::SendableRecordBatchStream; +use crate::execution::runtime_env::RuntimeEnv; use crate::physical_plan::common::spawn_execution; use pin_project_lite::pin_project; @@ -97,7 +98,11 @@ impl ExecutionPlan for CoalescePartitionsExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { // CoalescePartitionsExec produces a single partition if 0 != partition { return Err(DataFusionError::Internal(format!( @@ -113,7 +118,7 @@ impl ExecutionPlan for CoalescePartitionsExec { )), 1 => { // bypass any threading / metrics if there is a single partition - self.input.execute(0).await + self.input.execute(0, runtime).await } _ => { let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); @@ -136,6 +141,7 @@ impl ExecutionPlan for CoalescePartitionsExec { self.input.clone(), sender.clone(), part_i, + runtime.clone(), )); } @@ -215,6 +221,7 @@ mod tests { #[tokio::test] async fn merge() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let num_partitions = 4; @@ -244,7 +251,7 @@ mod tests { assert_eq!(merge.output_partitioning().partition_count(), 1); // the result should contain 4 batches (one per input partition) - let iter = merge.execute(0).await?; + let iter = merge.execute(0, runtime).await?; let batches = common::collect(iter).await?; assert_eq!(batches.len(), num_partitions); @@ -257,6 +264,7 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -265,7 +273,7 @@ mod tests { let coaelesce_partitions_exec = Arc::new(CoalescePartitionsExec::new(blocking_exec)); - let fut = collect(coaelesce_partitions_exec); + let fut = collect(coaelesce_partitions_exec, runtime); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/src/physical_plan/common.rs b/datafusion/src/physical_plan/common.rs index d6a37e0efa16..dd0c8248e459 100644 --- a/datafusion/src/physical_plan/common.rs +++ b/datafusion/src/physical_plan/common.rs @@ -19,17 +19,19 @@ use super::{RecordBatchStream, SendableRecordBatchStream}; use crate::error::{DataFusionError, Result}; +use crate::execution::runtime_env::RuntimeEnv; use crate::physical_plan::{ColumnStatistics, ExecutionPlan, Statistics}; use arrow::compute::concat; use arrow::datatypes::{Schema, SchemaRef}; use arrow::error::ArrowError; use arrow::error::Result as ArrowResult; +use arrow::ipc::writer::FileWriter; use arrow::record_batch::RecordBatch; use futures::channel::mpsc; use futures::{Future, SinkExt, Stream, StreamExt, TryStreamExt}; use pin_project_lite::pin_project; use std::fs; -use std::fs::metadata; +use std::fs::{metadata, File}; use std::sync::Arc; use std::task::{Context, Poll}; use tokio::task::JoinHandle; @@ -163,9 +165,10 @@ pub(crate) fn spawn_execution( input: Arc, mut output: mpsc::Sender>, partition: usize, + runtime: Arc, ) -> JoinHandle<()> { tokio::spawn(async move { - let mut stream = match input.execute(partition).await { + let mut stream = match input.execute(partition, runtime).await { Err(e) => { // If send fails, plan being torn // down, no place to send the error @@ -195,12 +198,7 @@ pub fn compute_record_batch_statistics( ) -> Statistics { let nb_rows = batches.iter().flatten().map(RecordBatch::num_rows).sum(); - let total_byte_size = batches - .iter() - .flatten() - .flat_map(RecordBatch::columns) - .map(|a| a.get_array_memory_size()) - .sum(); + let total_byte_size = batches.iter().flatten().map(batch_byte_size).sum(); let projection = match projection { Some(p) => p, @@ -376,3 +374,65 @@ mod tests { Ok(()) } } + +/// Write in Arrow IPC format. +pub struct IPCWriter { + /// path + pub path: String, + /// Inner writer + pub writer: FileWriter, + /// bathes written + pub num_batches: u64, + /// rows written + pub num_rows: u64, + /// bytes written + pub num_bytes: u64, +} + +impl IPCWriter { + /// Create new writer + pub fn new(path: &str, schema: &Schema) -> Result { + let file = File::create(path).map_err(|e| { + DataFusionError::Execution(format!( + "Failed to create partition file at {}: {:?}", + path, e + )) + })?; + Ok(Self { + num_batches: 0, + num_rows: 0, + num_bytes: 0, + path: path.to_owned(), + writer: FileWriter::try_new(file, schema)?, + }) + } + + /// Write one single batch + pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { + self.writer.write(batch)?; + self.num_batches += 1; + self.num_rows += batch.num_rows() as u64; + let num_bytes: usize = batch_byte_size(batch); + self.num_bytes += num_bytes as u64; + Ok(()) + } + + /// Finish the writer + pub fn finish(&mut self) -> Result<()> { + self.writer.finish().map_err(DataFusionError::ArrowError) + } + + /// Path write to + pub fn path(&self) -> &str { + &self.path + } +} + +/// Returns the total number of bytes of memory occupied physically by this batch. +pub fn batch_byte_size(batch: &RecordBatch) -> usize { + batch + .columns() + .iter() + .map(|array| array.get_array_memory_size()) + .sum() +} diff --git a/datafusion/src/physical_plan/cross_join.rs b/datafusion/src/physical_plan/cross_join.rs index a70d777ccf81..087507e1dece 100644 --- a/datafusion/src/physical_plan/cross_join.rs +++ b/datafusion/src/physical_plan/cross_join.rs @@ -42,6 +42,7 @@ use super::{ coalesce_batches::concat_batches, memory::MemoryStream, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, }; +use crate::execution::runtime_env::RuntimeEnv; use log::debug; /// Data of the left side @@ -136,7 +137,11 @@ impl ExecutionPlan for CrossJoinExec { self.right.output_partitioning() } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { // we only want to compute the build side once let left_data = { let mut build_side = self.build_side.lock().await; @@ -148,7 +153,7 @@ impl ExecutionPlan for CrossJoinExec { // merge all left parts into a single stream let merge = CoalescePartitionsExec::new(self.left.clone()); - let stream = merge.execute(0).await?; + let stream = merge.execute(0, runtime.clone()).await?; // Load all batches and count the rows let (batches, num_rows) = stream @@ -173,7 +178,7 @@ impl ExecutionPlan for CrossJoinExec { } }; - let stream = self.right.execute(partition).await?; + let stream = self.right.execute(partition, runtime.clone()).await?; if left_data.num_rows() == 0 { return Ok(Box::pin(MemoryStream::try_new( diff --git a/datafusion/src/physical_plan/empty.rs b/datafusion/src/physical_plan/empty.rs index 46b50020fe0d..33a09d97bbe8 100644 --- a/datafusion/src/physical_plan/empty.rs +++ b/datafusion/src/physical_plan/empty.rs @@ -30,6 +30,7 @@ use arrow::record_batch::RecordBatch; use super::{common, SendableRecordBatchStream, Statistics}; +use crate::execution::runtime_env::RuntimeEnv; use async_trait::async_trait; /// Execution plan for empty relation (produces no rows) @@ -109,7 +110,11 @@ impl ExecutionPlan for EmptyExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + _runtime: Arc, + ) -> Result { // GlobalLimitExec has a single output partition if 0 != partition { return Err(DataFusionError::Internal(format!( @@ -152,13 +157,14 @@ mod tests { #[tokio::test] async fn empty() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let empty = EmptyExec::new(false, schema.clone()); assert_eq!(empty.schema(), schema); // we should have no results - let iter = empty.execute(0).await?; + let iter = empty.execute(0, runtime).await?; let batches = common::collect(iter).await?; assert!(batches.is_empty()); @@ -183,21 +189,23 @@ mod tests { #[tokio::test] async fn invalid_execute() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let empty = EmptyExec::new(false, schema); // ask for the wrong partition - assert!(empty.execute(1).await.is_err()); - assert!(empty.execute(20).await.is_err()); + assert!(empty.execute(1, runtime.clone()).await.is_err()); + assert!(empty.execute(20, runtime.clone()).await.is_err()); Ok(()) } #[tokio::test] async fn produce_one_row() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let empty = EmptyExec::new(true, schema); - let iter = empty.execute(0).await?; + let iter = empty.execute(0, runtime).await?; let batches = common::collect(iter).await?; // should have one item diff --git a/datafusion/src/physical_plan/explain.rs b/datafusion/src/physical_plan/explain.rs index 74093259aaf6..df3dc98f196d 100644 --- a/datafusion/src/physical_plan/explain.rs +++ b/datafusion/src/physical_plan/explain.rs @@ -31,6 +31,7 @@ use crate::{ use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch}; use super::SendableRecordBatchStream; +use crate::execution::runtime_env::RuntimeEnv; use async_trait::async_trait; /// Explain execution plan operator. This operator contains the string @@ -101,7 +102,11 @@ impl ExecutionPlan for ExplainExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + _runtime: Arc, + ) -> Result { if 0 != partition { return Err(DataFusionError::Internal(format!( "ExplainExec invalid partition {}", diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index a85d86708557..9ba7eaaea343 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -113,7 +113,7 @@ pub struct PhysicalSortExpr { } impl std::fmt::Display for PhysicalSortExpr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { let opts_string = match (self.options.descending, self.options.nulls_first) { (true, true) => "DESC", (true, false) => "DESC NULLS LAST", diff --git a/datafusion/src/physical_plan/file_format/avro.rs b/datafusion/src/physical_plan/file_format/avro.rs index b50c0a082686..6ab5bf9e8c79 100644 --- a/datafusion/src/physical_plan/file_format/avro.rs +++ b/datafusion/src/physical_plan/file_format/avro.rs @@ -26,6 +26,7 @@ use arrow::datatypes::SchemaRef; #[cfg(feature = "avro")] use arrow::error::ArrowError; +use crate::execution::runtime_env::RuntimeEnv; use async_trait::async_trait; use std::any::Any; use std::sync::Arc; @@ -92,14 +93,22 @@ impl ExecutionPlan for AvroExec { } #[cfg(not(feature = "avro"))] - async fn execute(&self, _partition: usize) -> Result { + async fn execute( + &self, + _partition: usize, + _runtime: Arc, + ) -> Result { Err(DataFusionError::NotImplemented( "Cannot execute avro plan without avro feature enabled".to_string(), )) } #[cfg(feature = "avro")] - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + _runtime: Arc, + ) -> Result { let proj = self.base_config.projected_file_column_names(); let batch_size = self.base_config.batch_size; diff --git a/datafusion/src/physical_plan/file_format/csv.rs b/datafusion/src/physical_plan/file_format/csv.rs index f250baa1b36c..ea965a419560 100644 --- a/datafusion/src/physical_plan/file_format/csv.rs +++ b/datafusion/src/physical_plan/file_format/csv.rs @@ -27,6 +27,7 @@ use arrow::datatypes::SchemaRef; use std::any::Any; use std::sync::Arc; +use crate::execution::runtime_env::RuntimeEnv; use async_trait::async_trait; use super::file_stream::{BatchIter, FileStream}; @@ -106,7 +107,11 @@ impl ExecutionPlan for CsvExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + _runtime: Arc, + ) -> Result { let batch_size = self.base_config.batch_size; let file_schema = Arc::clone(&self.base_config.file_schema); let file_projection = self.base_config.file_column_projection_indices(); @@ -175,6 +180,7 @@ mod tests { #[tokio::test] async fn csv_exec_with_projection() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let file_schema = aggr_test_schema(); let testdata = crate::test_util::arrow_test_data(); let filename = "aggregate_test_100.csv"; @@ -197,7 +203,7 @@ mod tests { assert_eq!(3, csv.projected_schema.fields().len()); assert_eq!(3, csv.schema().fields().len()); - let mut stream = csv.execute(0).await?; + let mut stream = csv.execute(0, runtime).await?; let batch = stream.next().await.unwrap()?; assert_eq!(3, batch.num_columns()); assert_eq!(100, batch.num_rows()); @@ -221,6 +227,7 @@ mod tests { #[tokio::test] async fn csv_exec_with_limit() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let file_schema = aggr_test_schema(); let testdata = crate::test_util::arrow_test_data(); let filename = "aggregate_test_100.csv"; @@ -243,7 +250,7 @@ mod tests { assert_eq!(13, csv.projected_schema.fields().len()); assert_eq!(13, csv.schema().fields().len()); - let mut it = csv.execute(0).await?; + let mut it = csv.execute(0, runtime).await?; let batch = it.next().await.unwrap()?; assert_eq!(13, batch.num_columns()); assert_eq!(5, batch.num_rows()); @@ -267,6 +274,7 @@ mod tests { #[tokio::test] async fn csv_exec_with_partition() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let file_schema = aggr_test_schema(); let testdata = crate::test_util::arrow_test_data(); let filename = "aggregate_test_100.csv"; @@ -296,7 +304,7 @@ mod tests { assert_eq!(2, csv.projected_schema.fields().len()); assert_eq!(2, csv.schema().fields().len()); - let mut it = csv.execute(0).await?; + let mut it = csv.execute(0, runtime).await?; let batch = it.next().await.unwrap()?; assert_eq!(2, batch.num_columns()); assert_eq!(100, batch.num_rows()); diff --git a/datafusion/src/physical_plan/file_format/json.rs b/datafusion/src/physical_plan/file_format/json.rs index 9032eb9d5e5d..a0959c23b657 100644 --- a/datafusion/src/physical_plan/file_format/json.rs +++ b/datafusion/src/physical_plan/file_format/json.rs @@ -19,6 +19,7 @@ use async_trait::async_trait; use crate::error::{DataFusionError, Result}; +use crate::execution::runtime_env::RuntimeEnv; use crate::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; @@ -82,7 +83,11 @@ impl ExecutionPlan for NdJsonExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + _runtime: Arc, + ) -> Result { let proj = self.base_config.projected_file_column_names(); let batch_size = self.base_config.batch_size; @@ -154,6 +159,7 @@ mod tests { #[tokio::test] async fn nd_json_exec_file_without_projection() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); use arrow::datatypes::DataType; let path = format!("{}/1.json", TEST_DATA_BASE); let exec = NdJsonExec::new(PhysicalPlanConfig { @@ -191,7 +197,7 @@ mod tests { &DataType::Utf8 ); - let mut it = exec.execute(0).await?; + let mut it = exec.execute(0, runtime).await?; let batch = it.next().await.unwrap()?; assert_eq!(batch.num_rows(), 3); @@ -209,6 +215,7 @@ mod tests { #[tokio::test] async fn nd_json_exec_file_projection() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let path = format!("{}/1.json", TEST_DATA_BASE); let exec = NdJsonExec::new(PhysicalPlanConfig { object_store: Arc::new(LocalFileSystem {}), @@ -228,7 +235,7 @@ mod tests { inferred_schema.field_with_name("c").unwrap(); inferred_schema.field_with_name("d").unwrap_err(); - let mut it = exec.execute(0).await?; + let mut it = exec.execute(0, runtime).await?; let batch = it.next().await.unwrap()?; assert_eq!(batch.num_rows(), 4); diff --git a/datafusion/src/physical_plan/file_format/parquet.rs b/datafusion/src/physical_plan/file_format/parquet.rs index 355a98c90e95..78c9428e39db 100644 --- a/datafusion/src/physical_plan/file_format/parquet.rs +++ b/datafusion/src/physical_plan/file_format/parquet.rs @@ -59,6 +59,7 @@ use tokio::{ task, }; +use crate::execution::runtime_env::RuntimeEnv; use async_trait::async_trait; use super::PartitionColumnProjector; @@ -186,7 +187,11 @@ impl ExecutionPlan for ParquetExec { } } - async fn execute(&self, partition_index: usize) -> Result { + async fn execute( + &self, + partition_index: usize, + _runtime: Arc, + ) -> Result { // because the parquet implementation is not thread-safe, it is necessary to execute // on a thread and communicate with channels let (response_tx, response_rx): ( @@ -478,6 +483,7 @@ mod tests { #[tokio::test] async fn parquet_exec_with_projection() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let testdata = crate::test_util::parquet_test_data(); let filename = format!("{}/alltypes_plain.parquet", testdata); let parquet_exec = ParquetExec::new( @@ -497,7 +503,7 @@ mod tests { ); assert_eq!(parquet_exec.output_partitioning().partition_count(), 1); - let mut results = parquet_exec.execute(0).await?; + let mut results = parquet_exec.execute(0, runtime).await?; let batch = results.next().await.unwrap()?; assert_eq!(8, batch.num_rows()); @@ -522,6 +528,7 @@ mod tests { #[tokio::test] async fn parquet_exec_with_partition() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let testdata = crate::test_util::parquet_test_data(); let filename = format!("{}/alltypes_plain.parquet", testdata); let mut partitioned_file = local_unpartitioned_file(filename.clone()); @@ -552,7 +559,7 @@ mod tests { ); assert_eq!(parquet_exec.output_partitioning().partition_count(), 1); - let mut results = parquet_exec.execute(0).await?; + let mut results = parquet_exec.execute(0, runtime).await?; let batch = results.next().await.unwrap()?; let expected = vec![ "+----+----------+-------------+-------+", diff --git a/datafusion/src/physical_plan/filter.rs b/datafusion/src/physical_plan/filter.rs index a32371a1e481..dab5facfd905 100644 --- a/datafusion/src/physical_plan/filter.rs +++ b/datafusion/src/physical_plan/filter.rs @@ -37,6 +37,7 @@ use arrow::record_batch::RecordBatch; use async_trait::async_trait; +use crate::execution::runtime_env::RuntimeEnv; use futures::stream::{Stream, StreamExt}; /// FilterExec evaluates a boolean predicate against all input batches to determine which rows to @@ -118,13 +119,17 @@ impl ExecutionPlan for FilterExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); Ok(Box::pin(FilterExecStream { schema: self.input.schema().clone(), predicate: self.predicate.clone(), - input: self.input.execute(partition).await?, + input: self.input.execute(partition, runtime).await?, baseline_metrics, })) } @@ -234,6 +239,7 @@ mod tests { #[tokio::test] async fn simple_predicate() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let partitions = 4; @@ -275,7 +281,7 @@ mod tests { let filter: Arc = Arc::new(FilterExec::try_new(predicate, Arc::new(csv))?); - let results = collect(filter).await?; + let results = collect(filter, runtime).await?; results .iter() diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index df073b62c5b7..6bff76f010d6 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -1547,7 +1547,7 @@ pub struct ScalarFunctionExpr { } impl Debug for ScalarFunctionExpr { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { f.debug_struct("ScalarFunctionExpr") .field("fun", &"") .field("name", &self.name) diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs index 4698ba5dbb0d..f15e8f0fb47e 100644 --- a/datafusion/src/physical_plan/hash_aggregate.rs +++ b/datafusion/src/physical_plan/hash_aggregate.rs @@ -48,6 +48,7 @@ use arrow::{ use hashbrown::raw::RawTable; use pin_project_lite::pin_project; +use crate::execution::runtime_env::RuntimeEnv; use async_trait::async_trait; use super::common::AbortOnDropSingle; @@ -207,8 +208,12 @@ impl ExecutionPlan for HashAggregateExec { self.input.output_partitioning() } - async fn execute(&self, partition: usize) -> Result { - let input = self.input.execute(partition).await?; + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { + let input = self.input.execute(partition, runtime).await?; let group_expr = self.group_expr.iter().map(|x| x.0.clone()).collect(); let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); @@ -614,7 +619,7 @@ struct Accumulators { } impl std::fmt::Debug for Accumulators { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { // hashes are not store inline, so could only get values let map_string = "RawTable"; f.debug_struct("Accumulators") @@ -1061,6 +1066,8 @@ mod tests { DataType::Float64, ))]; + let runtime = Arc::new(RuntimeEnv::default()); + let partial_aggregate = Arc::new(HashAggregateExec::try_new( AggregateMode::Partial, groups.clone(), @@ -1069,7 +1076,8 @@ mod tests { input_schema.clone(), )?); - let result = common::collect(partial_aggregate.execute(0).await?).await?; + let result = + common::collect(partial_aggregate.execute(0, runtime.clone()).await?).await?; let expected = vec![ "+---+---------------+-------------+", @@ -1100,7 +1108,8 @@ mod tests { input_schema, )?); - let result = common::collect(merged_aggregate.execute(0).await?).await?; + let result = + common::collect(merged_aggregate.execute(0, runtime.clone()).await?).await?; assert_eq!(result.len(), 1); let batch = &result[0]; @@ -1161,7 +1170,11 @@ mod tests { ))) } - async fn execute(&self, _partition: usize) -> Result { + async fn execute( + &self, + _partition: usize, + _runtime: Arc, + ) -> Result { let stream; if self.yield_first { stream = TestYieldingStream::New; @@ -1237,6 +1250,7 @@ mod tests { #[tokio::test] async fn test_drop_cancel_without_groups() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -1258,7 +1272,7 @@ mod tests { schema, )?); - let fut = crate::physical_plan::collect(hash_aggregate_exec); + let fut = crate::physical_plan::collect(hash_aggregate_exec, runtime); let mut fut = fut.boxed(); assert_is_pending(&mut fut); @@ -1270,6 +1284,7 @@ mod tests { #[tokio::test] async fn test_drop_cancel_with_groups() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Float32, true), Field::new("b", DataType::Float32, true), @@ -1294,7 +1309,7 @@ mod tests { schema, )?); - let fut = crate::physical_plan::collect(hash_aggregate_exec); + let fut = crate::physical_plan::collect(hash_aggregate_exec, runtime); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs index 8cb2f44db281..39479f9485e5 100644 --- a/datafusion/src/physical_plan/hash_join.rs +++ b/datafusion/src/physical_plan/hash_join.rs @@ -70,6 +70,7 @@ use super::{ }; use crate::arrow::array::BooleanBufferBuilder; use crate::arrow::datatypes::TimeUnit; +use crate::execution::runtime_env::RuntimeEnv; use crate::physical_plan::coalesce_batches::concat_batches; use crate::physical_plan::PhysicalExpr; use log::debug; @@ -90,7 +91,7 @@ use std::fmt; struct JoinHashMap(RawTable<(u64, SmallVec<[u64; 1]>)>); impl fmt::Debug for JoinHashMap { - fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { Ok(()) } } @@ -277,7 +278,11 @@ impl ExecutionPlan for HashJoinExec { self.right.output_partitioning() } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { let on_left = self.on.iter().map(|on| on.0.clone()).collect::>(); // we only want to compute the build side once for PartitionMode::CollectLeft let left_data = { @@ -292,7 +297,7 @@ impl ExecutionPlan for HashJoinExec { // merge all left parts into a single stream let merge = CoalescePartitionsExec::new(self.left.clone()); - let stream = merge.execute(0).await?; + let stream = merge.execute(0, runtime.clone()).await?; // This operation performs 2 steps at once: // 1. creates a [JoinHashMap] of all batches from the stream @@ -345,7 +350,7 @@ impl ExecutionPlan for HashJoinExec { let start = Instant::now(); // Load 1 partition of left side in memory - let stream = self.left.execute(partition).await?; + let stream = self.left.execute(partition, runtime.clone()).await?; // This operation performs 2 steps at once: // 1. creates a [JoinHashMap] of all batches from the stream @@ -396,7 +401,7 @@ impl ExecutionPlan for HashJoinExec { // we have the batches and the hash map with their keys. We can how create a stream // over the right that uses this information to issue new batches. - let right_stream = self.right.execute(partition).await?; + let right_stream = self.right.execute(partition, runtime.clone()).await?; let on_right = self.on.iter().map(|on| on.1.clone()).collect::>(); let num_rows = left_data.1.num_rows(); @@ -1084,11 +1089,12 @@ mod tests { on: JoinOn, join_type: &JoinType, null_equals_null: bool, + runtime: Arc, ) -> Result<(Vec, Vec)> { let join = join(left, right, on, join_type, null_equals_null)?; let columns = columns(&join.schema()); - let stream = join.execute(0).await?; + let stream = join.execute(0, runtime).await?; let batches = common::collect(stream).await?; Ok((columns, batches)) @@ -1100,6 +1106,7 @@ mod tests { on: JoinOn, join_type: &JoinType, null_equals_null: bool, + runtime: Arc, ) -> Result<(Vec, Vec)> { let partition_count = 4; @@ -1132,7 +1139,7 @@ mod tests { let mut batches = vec![]; for i in 0..partition_count { - let stream = join.execute(i).await?; + let stream = join.execute(i, runtime.clone()).await?; let more_batches = common::collect(stream).await?; batches.extend( more_batches @@ -1147,6 +1154,7 @@ mod tests { #[tokio::test] async fn join_inner_one() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1169,6 +1177,7 @@ mod tests { on.clone(), &JoinType::Inner, false, + runtime, ) .await?; @@ -1190,6 +1199,7 @@ mod tests { #[tokio::test] async fn partitioned_join_inner_one() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1211,6 +1221,7 @@ mod tests { on.clone(), &JoinType::Inner, false, + runtime, ) .await?; @@ -1232,6 +1243,7 @@ mod tests { #[tokio::test] async fn join_inner_one_no_shared_column_names() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1248,7 +1260,7 @@ mod tests { )]; let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false).await?; + join_collect(left, right, on, &JoinType::Inner, false, runtime).await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -1269,6 +1281,7 @@ mod tests { #[tokio::test] async fn join_inner_two() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 2]), ("b2", &vec![1, 2, 2]), @@ -1291,7 +1304,7 @@ mod tests { ]; let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false).await?; + join_collect(left, right, on, &JoinType::Inner, false, runtime).await?; assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); @@ -1315,6 +1328,7 @@ mod tests { /// Test where the left has 2 parts, the right with 1 part => 1 part #[tokio::test] async fn join_inner_one_two_parts_left() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let batch1 = build_table_i32( ("a1", &vec![1, 2]), ("b2", &vec![1, 2]), @@ -1344,7 +1358,7 @@ mod tests { ]; let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false).await?; + join_collect(left, right, on, &JoinType::Inner, false, runtime).await?; assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); @@ -1368,6 +1382,7 @@ mod tests { /// Test where the left has 1 part, the right has 2 parts => 2 parts #[tokio::test] async fn join_inner_one_two_parts_right() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1397,7 +1412,7 @@ mod tests { assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); // first part - let stream = join.execute(0).await?; + let stream = join.execute(0, runtime.clone()).await?; let batches = common::collect(stream).await?; assert_eq!(batches.len(), 1); @@ -1411,7 +1426,7 @@ mod tests { assert_batches_sorted_eq!(expected, &batches); // second part - let stream = join.execute(1).await?; + let stream = join.execute(1, runtime.clone()).await?; let batches = common::collect(stream).await?; assert_eq!(batches.len(), 1); let expected = vec![ @@ -1442,6 +1457,7 @@ mod tests { #[tokio::test] async fn join_left_multi_batch() { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1462,7 +1478,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); - let stream = join.execute(0).await.unwrap(); + let stream = join.execute(0, runtime).await.unwrap(); let batches = common::collect(stream).await.unwrap(); let expected = vec![ @@ -1482,6 +1498,7 @@ mod tests { #[tokio::test] async fn join_full_multi_batch() { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1503,7 +1520,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - let stream = join.execute(0).await.unwrap(); + let stream = join.execute(0, runtime).await.unwrap(); let batches = common::collect(stream).await.unwrap(); let expected = vec![ @@ -1525,6 +1542,7 @@ mod tests { #[tokio::test] async fn join_left_empty_right() { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -1542,7 +1560,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); - let stream = join.execute(0).await.unwrap(); + let stream = join.execute(0, runtime).await.unwrap(); let batches = common::collect(stream).await.unwrap(); let expected = vec![ @@ -1560,6 +1578,7 @@ mod tests { #[tokio::test] async fn join_full_empty_right() { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -1577,7 +1596,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - let stream = join.execute(0).await.unwrap(); + let stream = join.execute(0, runtime).await.unwrap(); let batches = common::collect(stream).await.unwrap(); let expected = vec![ @@ -1595,6 +1614,7 @@ mod tests { #[tokio::test] async fn join_left_one() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1616,6 +1636,7 @@ mod tests { on.clone(), &JoinType::Left, false, + runtime, ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -1636,6 +1657,7 @@ mod tests { #[tokio::test] async fn partitioned_join_left_one() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1657,6 +1679,7 @@ mod tests { on.clone(), &JoinType::Left, false, + runtime, ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -1677,6 +1700,7 @@ mod tests { #[tokio::test] async fn join_semi() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 2, 3]), ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right @@ -1697,7 +1721,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1"]); - let stream = join.execute(0).await?; + let stream = join.execute(0, runtime).await?; let batches = common::collect(stream).await?; let expected = vec![ @@ -1716,6 +1740,7 @@ mod tests { #[tokio::test] async fn join_anti() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 2, 3, 5]), ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right @@ -1736,7 +1761,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1"]); - let stream = join.execute(0).await?; + let stream = join.execute(0, runtime).await?; let batches = common::collect(stream).await?; let expected = vec![ @@ -1753,6 +1778,7 @@ mod tests { #[tokio::test] async fn join_right_one() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -1769,7 +1795,7 @@ mod tests { )]; let (columns, batches) = - join_collect(left, right, on, &JoinType::Right, false).await?; + join_collect(left, right, on, &JoinType::Right, false, runtime).await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -1790,6 +1816,7 @@ mod tests { #[tokio::test] async fn partitioned_join_right_one() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -1806,7 +1833,8 @@ mod tests { )]; let (columns, batches) = - partitioned_join_collect(left, right, on, &JoinType::Right, false).await?; + partitioned_join_collect(left, right, on, &JoinType::Right, false, runtime) + .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -1827,6 +1855,7 @@ mod tests { #[tokio::test] async fn join_full_one() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1847,7 +1876,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - let stream = join.execute(0).await?; + let stream = join.execute(0, runtime).await?; let batches = common::collect(stream).await?; let expected = vec![ @@ -1917,6 +1946,7 @@ mod tests { #[tokio::test] async fn join_with_duplicated_column_names() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let left = build_table( ("a", &vec![1, 2, 3]), ("b", &vec![4, 5, 7]), @@ -1938,7 +1968,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); - let stream = join.execute(0).await?; + let stream = join.execute(0, runtime).await?; let batches = common::collect(stream).await?; let expected = vec![ diff --git a/datafusion/src/physical_plan/limit.rs b/datafusion/src/physical_plan/limit.rs index ef492ec18320..3e096364e62b 100644 --- a/datafusion/src/physical_plan/limit.rs +++ b/datafusion/src/physical_plan/limit.rs @@ -40,6 +40,7 @@ use super::{ RecordBatchStream, SendableRecordBatchStream, Statistics, }; +use crate::execution::runtime_env::RuntimeEnv; use async_trait::async_trait; /// Limit execution plan @@ -113,7 +114,11 @@ impl ExecutionPlan for GlobalLimitExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { // GlobalLimitExec has a single output partition if 0 != partition { return Err(DataFusionError::Internal(format!( @@ -130,7 +135,7 @@ impl ExecutionPlan for GlobalLimitExec { } let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); - let stream = self.input.execute(0).await?; + let stream = self.input.execute(0, runtime).await?; Ok(Box::pin(LimitStream::new( stream, self.limit, @@ -242,9 +247,13 @@ impl ExecutionPlan for LocalLimitExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); - let stream = self.input.execute(partition).await?; + let stream = self.input.execute(partition, runtime).await?; Ok(Box::pin(LimitStream::new( stream, self.limit, @@ -392,6 +401,7 @@ mod tests { #[tokio::test] async fn limit() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let num_partitions = 4; @@ -420,7 +430,7 @@ mod tests { GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(Arc::new(csv))), 7); // the result should contain 4 batches (one per input partition) - let iter = limit.execute(0).await?; + let iter = limit.execute(0, runtime).await?; let batches = common::collect(iter).await?; // there should be a total of 100 rows diff --git a/datafusion/src/physical_plan/memory.rs b/datafusion/src/physical_plan/memory.rs index e2e6221cada6..15848c558916 100644 --- a/datafusion/src/physical_plan/memory.rs +++ b/datafusion/src/physical_plan/memory.rs @@ -31,6 +31,7 @@ use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; +use crate::execution::runtime_env::RuntimeEnv; use async_trait::async_trait; use futures::Stream; @@ -47,7 +48,7 @@ pub struct MemoryExec { } impl fmt::Debug for MemoryExec { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "partitions: [...]")?; write!(f, "schema: {:?}", self.projected_schema)?; write!(f, "projection: {:?}", self.projection) @@ -86,7 +87,11 @@ impl ExecutionPlan for MemoryExec { ))) } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + _runtime: Arc, + ) -> Result { Ok(Box::pin(MemoryStream::try_new( self.partitions[partition].clone(), self.projected_schema.clone(), @@ -252,6 +257,7 @@ mod tests { #[tokio::test] async fn test_with_projection() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let (schema, batch) = mock_data()?; let executor = MemoryExec::try_new(&[vec![batch]], schema, Some(vec![2, 1]))?; @@ -277,7 +283,7 @@ mod tests { ); // scan with projection - let mut it = executor.execute(0).await?; + let mut it = executor.execute(0, runtime).await?; let batch2 = it.next().await.unwrap()?; assert_eq!(2, batch2.schema().fields().len()); assert_eq!("c", batch2.schema().field(0).name()); @@ -289,6 +295,7 @@ mod tests { #[tokio::test] async fn test_without_projection() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let (schema, batch) = mock_data()?; let executor = MemoryExec::try_new(&[vec![batch]], schema, None)?; @@ -325,7 +332,7 @@ mod tests { ]) ); - let mut it = executor.execute(0).await?; + let mut it = executor.execute(0, runtime).await?; let batch1 = it.next().await.unwrap()?; assert_eq!(4, batch1.schema().fields().len()); assert_eq!(4, batch1.num_columns()); diff --git a/datafusion/src/physical_plan/metrics/mod.rs b/datafusion/src/physical_plan/metrics/mod.rs index 7c59c8dddd76..089550cee5cf 100644 --- a/datafusion/src/physical_plan/metrics/mod.rs +++ b/datafusion/src/physical_plan/metrics/mod.rs @@ -76,7 +76,7 @@ pub struct Metric { } impl Display for Metric { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "{}", self.value.name())?; let mut iter = self @@ -282,7 +282,7 @@ impl MetricsSet { impl Display for MetricsSet { /// format the MetricsSet as a single string - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { let mut is_first = true; for i in self.metrics.iter() { if !is_first { @@ -363,7 +363,7 @@ impl Label { } impl Display for Label { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "{}={}", self.name, self.value) } } diff --git a/datafusion/src/physical_plan/metrics/value.rs b/datafusion/src/physical_plan/metrics/value.rs index 1caf13ee724c..6944aab3b0ab 100644 --- a/datafusion/src/physical_plan/metrics/value.rs +++ b/datafusion/src/physical_plan/metrics/value.rs @@ -45,7 +45,7 @@ impl PartialEq for Count { } impl Display for Count { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "{}", self.value()) } } @@ -97,7 +97,7 @@ impl PartialEq for Time { } impl Display for Time { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { let duration = std::time::Duration::from_nanos(self.value() as u64); write!(f, "{:?}", duration) } @@ -216,7 +216,7 @@ impl PartialEq for Timestamp { } impl Display for Timestamp { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self.value() { None => write!(f, "NONE"), Some(v) => { @@ -416,7 +416,7 @@ impl MetricValue { impl std::fmt::Display for MetricValue { /// Prints the value of this metric - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { Self::OutputRows(count) | Self::Count { count, .. } => { write!(f, "{}", count) diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index d39a7a006663..836d3994343f 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -25,6 +25,7 @@ use self::{ use crate::physical_plan::expressions::PhysicalSortExpr; use crate::{ error::{DataFusionError, Result}, + execution::runtime_env::RuntimeEnv, scalar::ScalarValue, }; use arrow::compute::kernels::partition::lexicographical_partition_ranges; @@ -154,7 +155,11 @@ pub trait ExecutionPlan: Debug + Send + Sync { ) -> Result>; /// creates an iterator - async fn execute(&self, partition: usize) -> Result; + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result; /// Return a snapshot of the set of [`Metric`]s for this /// [`ExecutionPlan`]. @@ -310,24 +315,28 @@ pub fn visit_execution_plan( } /// Execute the [ExecutionPlan] and collect the results in memory -pub async fn collect(plan: Arc) -> Result> { - let stream = execute_stream(plan).await?; +pub async fn collect( + plan: Arc, + runtime: Arc, +) -> Result> { + let stream = execute_stream(plan, runtime).await?; common::collect(stream).await } /// Execute the [ExecutionPlan] and return a single stream of results pub async fn execute_stream( plan: Arc, + runtime: Arc, ) -> Result { match plan.output_partitioning().partition_count() { 0 => Ok(Box::pin(EmptyRecordBatchStream::new(plan.schema()))), - 1 => plan.execute(0).await, + 1 => plan.execute(0, runtime).await, _ => { // merge into a single partition let plan = CoalescePartitionsExec::new(plan.clone()); // CoalescePartitionsExec must produce a single partition assert_eq!(1, plan.output_partitioning().partition_count()); - plan.execute(0).await + plan.execute(0, runtime).await } } } @@ -335,8 +344,9 @@ pub async fn execute_stream( /// Execute the [ExecutionPlan] and collect the results in memory pub async fn collect_partitioned( plan: Arc, + runtime: Arc, ) -> Result>> { - let streams = execute_stream_partitioned(plan).await?; + let streams = execute_stream_partitioned(plan, runtime).await?; let mut batches = Vec::with_capacity(streams.len()); for stream in streams { batches.push(common::collect(stream).await?); @@ -347,11 +357,12 @@ pub async fn collect_partitioned( /// Execute the [ExecutionPlan] and return a vec with one stream per output partition pub async fn execute_stream_partitioned( plan: Arc, + runtime: Arc, ) -> Result> { let num_partitions = plan.output_partitioning().partition_count(); let mut streams = Vec::with_capacity(num_partitions); for i in 0..num_partitions { - streams.push(plan.execute(i).await?); + streams.push(plan.execute(i, runtime.clone()).await?); } Ok(streams) } @@ -651,8 +662,7 @@ pub mod projection; #[cfg(feature = "regex_expressions")] pub mod regex_expressions; pub mod repartition; -pub mod sort; -pub mod sort_preserving_merge; +pub mod sorts; pub mod stream; pub mod string_expressions; pub mod type_coercion; diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 6d913ac0f27c..1ab78b24a482 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -45,7 +45,7 @@ use crate::physical_plan::hash_join::HashJoinExec; use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; -use crate::physical_plan::sort::SortExec; +use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::udf; use crate::physical_plan::windows::WindowAggExec; use crate::physical_plan::{join_utils, Partitioning}; @@ -1461,6 +1461,7 @@ mod tests { use super::*; use crate::datasource::object_store::local::LocalFileSystem; use crate::execution::options::CsvReadOptions; + use crate::execution::runtime_env::RuntimeEnv; use crate::logical_plan::plan::Extension; use crate::logical_plan::{DFField, DFSchema, DFSchemaRef}; use crate::physical_plan::{ @@ -1920,7 +1921,11 @@ mod tests { unimplemented!("NoOpExecutionPlan::with_new_children"); } - async fn execute(&self, _partition: usize) -> Result { + async fn execute( + &self, + _partition: usize, + _runtime: Arc, + ) -> Result { unimplemented!("NoOpExecutionPlan::execute"); } diff --git a/datafusion/src/physical_plan/projection.rs b/datafusion/src/physical_plan/projection.rs index 98317b3ff487..d86548b2e217 100644 --- a/datafusion/src/physical_plan/projection.rs +++ b/datafusion/src/physical_plan/projection.rs @@ -37,6 +37,7 @@ use arrow::record_batch::RecordBatch; use super::expressions::Column; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::{RecordBatchStream, SendableRecordBatchStream, Statistics}; +use crate::execution::runtime_env::RuntimeEnv; use async_trait::async_trait; use futures::stream::Stream; use futures::stream::StreamExt; @@ -136,11 +137,15 @@ impl ExecutionPlan for ProjectionExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { Ok(Box::pin(ProjectionStream { schema: self.schema.clone(), expr: self.expr.iter().map(|x| x.0.clone()).collect(), - input: self.input.execute(partition).await?, + input: self.input.execute(partition, runtime).await?, baseline_metrics: BaselineMetrics::new(&self.metrics, partition), })) } @@ -293,6 +298,7 @@ mod tests { #[tokio::test] async fn project_first_column() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let partitions = 4; @@ -329,7 +335,7 @@ mod tests { let mut row_count = 0; for partition in 0..projection.output_partitioning().partition_count() { partition_count += 1; - let stream = projection.execute(partition).await?; + let stream = projection.execute(partition, runtime.clone()).await?; row_count += stream .map(|batch| { diff --git a/datafusion/src/physical_plan/repartition.rs b/datafusion/src/physical_plan/repartition.rs index a3a5b0618a9e..5549794ed9b8 100644 --- a/datafusion/src/physical_plan/repartition.rs +++ b/datafusion/src/physical_plan/repartition.rs @@ -36,6 +36,7 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use super::{RecordBatchStream, SendableRecordBatchStream}; use async_trait::async_trait; +use crate::execution::runtime_env::RuntimeEnv; use futures::stream::Stream; use futures::StreamExt; use hashbrown::HashMap; @@ -164,7 +165,11 @@ impl ExecutionPlan for RepartitionExec { self.partitioning.clone() } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { // lock mutexes let mut state = self.state.lock().await; @@ -207,6 +212,7 @@ impl ExecutionPlan for RepartitionExec { txs.clone(), self.partitioning.clone(), r_metrics, + runtime.clone(), )); // In a separate task, wait for each input to be done @@ -285,12 +291,13 @@ impl RepartitionExec { mut txs: HashMap>>>, partitioning: Partitioning, r_metrics: RepartitionMetrics, + runtime: Arc, ) -> Result<()> { let num_output_partitions = txs.len(); // execute the child operator let timer = r_metrics.fetch_time.timer(); - let mut stream = input.execute(i).await?; + let mut stream = input.execute(i, runtime).await?; timer.done(); let mut counter = 0; @@ -616,6 +623,7 @@ mod tests { input_partitions: Vec>, partitioning: Partitioning, ) -> Result>> { + let runtime = Arc::new(RuntimeEnv::default()); // create physical plan let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?; let exec = RepartitionExec::try_new(Arc::new(exec), partitioning)?; @@ -624,7 +632,7 @@ mod tests { let mut output_partitions = vec![]; for i in 0..exec.partitioning.partition_count() { // execute this *output* partition and collect all batches - let mut stream = exec.execute(i).await?; + let mut stream = exec.execute(i, runtime.clone()).await?; let mut batches = vec![]; while let Some(result) = stream.next().await { batches.push(result?); @@ -664,6 +672,7 @@ mod tests { #[tokio::test] async fn unsupported_partitioning() { + let runtime = Arc::new(RuntimeEnv::default()); // have to send at least one batch through to provoke error let batch = RecordBatch::try_from_iter(vec![( "my_awesome_field", @@ -678,7 +687,7 @@ mod tests { // returned and no results produced let partitioning = Partitioning::UnknownPartitioning(1); let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); - let output_stream = exec.execute(0).await.unwrap(); + let output_stream = exec.execute(0, runtime).await.unwrap(); // Expect that an error is returned let result_string = crate::physical_plan::common::collect(output_stream) @@ -698,13 +707,14 @@ mod tests { // This generates an error on a call to execute. The error // should be returned and no results produced. + let runtime = Arc::new(RuntimeEnv::default()); let input = ErrorExec::new(); let partitioning = Partitioning::RoundRobinBatch(1); let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); // Note: this should pass (the stream can be created) but the // error when the input is executed should get passed back - let output_stream = exec.execute(0).await.unwrap(); + let output_stream = exec.execute(0, runtime).await.unwrap(); // Expect that an error is returned let result_string = crate::physical_plan::common::collect(output_stream) @@ -720,6 +730,7 @@ mod tests { #[tokio::test] async fn repartition_with_error_in_stream() { + let runtime = Arc::new(RuntimeEnv::default()); let batch = RecordBatch::try_from_iter(vec![( "my_awesome_field", Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef, @@ -737,7 +748,7 @@ mod tests { // Note: this should pass (the stream can be created) but the // error when the input is executed should get passed back - let output_stream = exec.execute(0).await.unwrap(); + let output_stream = exec.execute(0, runtime).await.unwrap(); // Expect that an error is returned let result_string = crate::physical_plan::common::collect(output_stream) @@ -753,6 +764,7 @@ mod tests { #[tokio::test] async fn repartition_with_delayed_stream() { + let runtime = Arc::new(RuntimeEnv::default()); let batch1 = RecordBatch::try_from_iter(vec![( "my_awesome_field", Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef, @@ -787,7 +799,7 @@ mod tests { assert_batches_sorted_eq!(&expected, &expected_batches); - let output_stream = exec.execute(0).await.unwrap(); + let output_stream = exec.execute(0, runtime).await.unwrap(); let batches = crate::physical_plan::common::collect(output_stream) .await .unwrap(); @@ -797,6 +809,7 @@ mod tests { #[tokio::test] async fn robin_repartition_with_dropping_output_stream() { + let runtime = Arc::new(RuntimeEnv::default()); let partitioning = Partitioning::RoundRobinBatch(2); // The barrier exec waits to be pinged // requires the input to wait at least once) @@ -805,8 +818,8 @@ mod tests { // partition into two output streams let exec = RepartitionExec::try_new(input.clone(), partitioning).unwrap(); - let output_stream0 = exec.execute(0).await.unwrap(); - let output_stream1 = exec.execute(1).await.unwrap(); + let output_stream0 = exec.execute(0, runtime.clone()).await.unwrap(); + let output_stream1 = exec.execute(1, runtime.clone()).await.unwrap(); // now, purposely drop output stream 0 // *before* any outputs are produced @@ -839,6 +852,7 @@ mod tests { // wiht different compilers, we will compare the same execution with // and without droping the output stream. async fn hash_repartition_with_dropping_output_stream() { + let runtime = Arc::new(RuntimeEnv::default()); let partitioning = Partitioning::Hash( vec![Arc::new(crate::physical_plan::expressions::Column::new( "my_awesome_field", @@ -850,7 +864,7 @@ mod tests { // We first collect the results without droping the output stream. let input = Arc::new(make_barrier_exec()); let exec = RepartitionExec::try_new(input.clone(), partitioning.clone()).unwrap(); - let output_stream1 = exec.execute(1).await.unwrap(); + let output_stream1 = exec.execute(1, runtime.clone()).await.unwrap(); input.wait().await; let batches_without_drop = crate::physical_plan::common::collect(output_stream1) .await @@ -870,8 +884,8 @@ mod tests { // Now do the same but dropping the stream before waiting for the barrier let input = Arc::new(make_barrier_exec()); let exec = RepartitionExec::try_new(input.clone(), partitioning).unwrap(); - let output_stream0 = exec.execute(0).await.unwrap(); - let output_stream1 = exec.execute(1).await.unwrap(); + let output_stream0 = exec.execute(0, runtime.clone()).await.unwrap(); + let output_stream1 = exec.execute(1, runtime.clone()).await.unwrap(); // now, purposely drop output stream 0 // *before* any outputs are produced std::mem::drop(output_stream0); @@ -936,6 +950,7 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -946,7 +961,7 @@ mod tests { Partitioning::UnknownPartitioning(1), )?); - let fut = collect(repartition_exec); + let fut = collect(repartition_exec, runtime); let mut fut = fut.boxed(); assert_is_pending(&mut fut); @@ -958,6 +973,7 @@ mod tests { #[tokio::test] async fn hash_repartition_avoid_empty_batch() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let batch = RecordBatch::try_from_iter(vec![( "a", Arc::new(StringArray::from(vec!["foo"])) as ArrayRef, @@ -972,11 +988,11 @@ mod tests { let schema = batch.schema(); let input = MockExec::new(vec![Ok(batch)], schema); let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); - let output_stream0 = exec.execute(0).await.unwrap(); + let output_stream0 = exec.execute(0, runtime.clone()).await.unwrap(); let batch0 = crate::physical_plan::common::collect(output_stream0) .await .unwrap(); - let output_stream1 = exec.execute(1).await.unwrap(); + let output_stream1 = exec.execute(1, runtime.clone()).await.unwrap(); let batch1 = crate::physical_plan::common::collect(output_stream1) .await .unwrap(); diff --git a/datafusion/src/physical_plan/sorts/external_sort.rs b/datafusion/src/physical_plan/sorts/external_sort.rs new file mode 100644 index 000000000000..8550cb5ad433 --- /dev/null +++ b/datafusion/src/physical_plan/sorts/external_sort.rs @@ -0,0 +1,659 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the External-Sort plan + +use crate::error::{DataFusionError, Result}; +use crate::execution::memory_manager::{ + ConsumerType, MemoryConsumer, MemoryConsumerId, MemoryManager, +}; +use crate::execution::runtime_env::RuntimeEnv; +use crate::physical_plan::common::{batch_byte_size, IPCWriter, SizedRecordBatchStream}; +use crate::physical_plan::expressions::PhysicalSortExpr; +use crate::physical_plan::metrics::{ + BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, +}; +use crate::physical_plan::sorts::in_mem_sort::InMemSortStream; +use crate::physical_plan::sorts::sort::sort_batch; +use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeStream; +use crate::physical_plan::sorts::SortedStream; +use crate::physical_plan::stream::RecordBatchReceiverStream; +use crate::physical_plan::{ + DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + SendableRecordBatchStream, Statistics, +}; +use arrow::datatypes::SchemaRef; +use arrow::error::Result as ArrowResult; +use arrow::ipc::reader::FileReader; +use arrow::record_batch::RecordBatch; +use async_trait::async_trait; +use futures::lock::Mutex; +use futures::StreamExt; +use log::{error, info}; +use std::any::Any; +use std::fmt; +use std::fmt::{Debug, Formatter}; +use std::fs::File; +use std::io::BufReader; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use tokio::sync::mpsc::{Receiver as TKReceiver, Sender as TKSender}; +use tokio::task; + +/// Sort arbitrary size of data to get an total order (may spill several times during sorting based on free memory available). +/// +/// The basic architecture of the algorithm: +/// +/// let spills = vec![]; +/// let in_mem_batches = vec![]; +/// while (input.has_next()) { +/// let batch = input.next(); +/// // no enough memory available, spill first. +/// if exec_memory_available < size_of(batch) { +/// let ordered_stream = in_mem_heap_sort(in_mem_batches.drain(..)); +/// let tmp_file = spill_write(ordered_stream); +/// spills.push(tmp_file); +/// } +/// // sort the batch while it's probably still in cache and buffer it. +/// let sorted = sort_by_key(batch); +/// in_mem_batches.push(sorted); +/// } +/// +/// let partial_ordered_streams = vec![]; +/// let in_mem_stream = in_mem_heap_sort(in_mem_batches.drain(..)); +/// partial_ordered_streams.push(in_mem_stream); +/// partial_ordered_streams.extend(spills.drain(..).map(read_as_stream)); +/// let result = sort_preserving_merge(partial_ordered_streams); +struct ExternalSorter { + id: MemoryConsumerId, + schema: SchemaRef, + in_mem_batches: Mutex>, + spills: Mutex>, + /// Sort expressions + expr: Vec, + runtime: Arc, + metrics: ExecutionPlanMetricsSet, + used: AtomicUsize, + spilled_bytes: AtomicUsize, + spilled_count: AtomicUsize, +} + +impl ExternalSorter { + pub fn new( + partition_id: usize, + schema: SchemaRef, + expr: Vec, + runtime: Arc, + ) -> Self { + Self { + id: MemoryConsumerId::new(partition_id), + schema, + in_mem_batches: Mutex::new(vec![]), + spills: Mutex::new(vec![]), + expr, + runtime, + metrics: ExecutionPlanMetricsSet::new(), + used: AtomicUsize::new(0), + spilled_bytes: AtomicUsize::new(0), + spilled_count: AtomicUsize::new(0), + } + } + + async fn insert_batch(&self, input: RecordBatch) -> Result<()> { + let size = batch_byte_size(&input); + self.try_grow(size).await?; + self.used.fetch_add(size, Ordering::SeqCst); + // sort each batch as it's inserted, more probably to be cache-resident + let sorted_batch = sort_batch(input, self.schema.clone(), &*self.expr)?; + let mut in_mem_batches = self.in_mem_batches.lock().await; + in_mem_batches.push(sorted_batch); + Ok(()) + } + + /// MergeSort in mem batches as well as spills into total order with `SortPreservingMergeStream`. + async fn sort(&self) -> Result { + let partition = self.partition_id(); + let mut in_mem_batches = self.in_mem_batches.lock().await; + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + let mut streams: Vec = vec![]; + let in_mem_stream = in_mem_partial_sort( + &mut *in_mem_batches, + self.schema.clone(), + &self.expr, + self.runtime.batch_size(), + baseline_metrics, + ) + .await?; + streams.push(SortedStream::new(in_mem_stream, self.used())); + + let mut spills = self.spills.lock().await; + + for spill in spills.drain(..) { + let stream = read_spill_as_stream(spill, self.schema.clone()).await?; + streams.push(SortedStream::new(stream, 0)); + } + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + + Ok(Box::pin( + SortPreservingMergeStream::new_from_stream( + streams, + self.schema.clone(), + &self.expr, + self.runtime.batch_size(), + baseline_metrics, + partition, + self.runtime.clone(), + ) + .await, + )) + } + + fn used(&self) -> usize { + self.used.load(Ordering::SeqCst) + } + + fn spilled_bytes(&self) -> usize { + self.spilled_bytes.load(Ordering::SeqCst) + } + + fn spilled_count(&self) -> usize { + self.spilled_count.load(Ordering::SeqCst) + } +} + +impl Debug for ExternalSorter { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("ExternalSorter") + .field("id", &self.id()) + .field("memory_used", &self.used()) + .field("spilled_bytes", &self.spilled_bytes()) + .field("spilled_count", &self.spilled_count()) + .finish() + } +} + +#[async_trait] +impl MemoryConsumer for ExternalSorter { + fn name(&self) -> String { + "ExternalSorter".to_owned() + } + + fn id(&self) -> &MemoryConsumerId { + &self.id + } + + fn memory_manager(&self) -> Arc { + self.runtime.memory_manager.clone() + } + + fn type_(&self) -> &ConsumerType { + &ConsumerType::Requesting + } + + async fn spill(&self) -> Result { + info!( + "{}[{}] spilling sort data of {} to disk while inserting ({} time(s) so far)", + self.name(), + self.id(), + self.used(), + self.spilled_count() + ); + + let partition = self.partition_id(); + let mut in_mem_batches = self.in_mem_batches.lock().await; + // we could always get a chance to free some memory as long as we are holding some + if in_mem_batches.len() == 0 { + return Ok(0); + } + + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + + let path = self.runtime.disk_manager.create_tmp_file()?; + let stream = in_mem_partial_sort( + &mut *in_mem_batches, + self.schema.clone(), + &*self.expr, + self.runtime.batch_size(), + baseline_metrics, + ) + .await; + + let total_size = + spill_partial_sorted_stream(&mut stream?, path.clone(), self.schema.clone()) + .await?; + + let mut spills = self.spills.lock().await; + let used = self.used.swap(0, Ordering::SeqCst); + self.spilled_count.fetch_add(1, Ordering::SeqCst); + self.spilled_bytes.fetch_add(total_size, Ordering::SeqCst); + spills.push(path); + Ok(used) + } + + fn mem_used(&self) -> usize { + self.used.load(Ordering::SeqCst) + } +} + +/// consume the `sorted_bathes` and do in_mem_sort +async fn in_mem_partial_sort( + sorted_bathes: &mut Vec, + schema: SchemaRef, + expressions: &[PhysicalSortExpr], + target_batch_size: usize, + baseline_metrics: BaselineMetrics, +) -> Result { + if sorted_bathes.len() == 1 { + Ok(Box::pin(SizedRecordBatchStream::new( + schema, + vec![Arc::new(sorted_bathes.pop().unwrap())], + ))) + } else { + let new = sorted_bathes.drain(..).collect(); + assert_eq!(sorted_bathes.len(), 0); + Ok(Box::pin(InMemSortStream::new( + new, + schema, + expressions, + target_batch_size, + baseline_metrics, + )?)) + } +} + +async fn spill_partial_sorted_stream( + in_mem_stream: &mut SendableRecordBatchStream, + path: String, + schema: SchemaRef, +) -> Result { + let (sender, receiver) = tokio::sync::mpsc::channel(2); + while let Some(item) = in_mem_stream.next().await { + sender.send(Some(item)).await.ok(); + } + sender.send(None).await.ok(); + let path_clone = path.clone(); + let res = + task::spawn_blocking(move || write_sorted(receiver, path_clone, schema)).await; + match res { + Ok(r) => r, + Err(e) => Err(DataFusionError::Execution(format!( + "Error occurred while spilling {}", + e + ))), + } +} + +async fn read_spill_as_stream( + path: String, + schema: SchemaRef, +) -> Result { + let (sender, receiver): ( + TKSender>, + TKReceiver>, + ) = tokio::sync::mpsc::channel(2); + let path_clone = path.clone(); + let join_handle = task::spawn_blocking(move || { + if let Err(e) = read_spill(sender, path_clone) { + error!("Failure while reading spill file: {}. Error: {}", path, e); + } + }); + Ok(RecordBatchReceiverStream::create( + &schema, + receiver, + join_handle, + )) +} + +fn write_sorted( + mut receiver: TKReceiver>>, + path: String, + schema: SchemaRef, +) -> Result { + let mut writer = IPCWriter::new(path.as_ref(), schema.as_ref())?; + while let Some(Some(batch)) = receiver.blocking_recv() { + writer.write(&batch?)?; + } + writer.finish()?; + info!( + "Spilled {} batches of total {} rows to disk, memory released {}", + writer.num_batches, writer.num_rows, writer.num_bytes + ); + Ok(writer.num_bytes as usize) +} + +fn read_spill(sender: TKSender>, path: String) -> Result<()> { + let file = BufReader::new(File::open(&path)?); + let reader = FileReader::try_new(file)?; + for batch in reader { + sender + .blocking_send(batch) + .map_err(|e| DataFusionError::Execution(format!("{}", e)))?; + } + Ok(()) +} + +/// External Sort execution plan +#[derive(Debug)] +pub struct ExternalSortExec { + /// Input schema + input: Arc, + /// Sort expressions + expr: Vec, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, + /// Preserve partitions of input plan + preserve_partitioning: bool, +} + +impl ExternalSortExec { + /// Create a new sort execution plan + pub fn try_new( + expr: Vec, + input: Arc, + ) -> Result { + Ok(Self::new_with_partitioning(expr, input, false)) + } + + /// Create a new sort execution plan with the option to preserve + /// the partitioning of the input plan + pub fn new_with_partitioning( + expr: Vec, + input: Arc, + preserve_partitioning: bool, + ) -> Self { + Self { + expr, + input, + metrics: ExecutionPlanMetricsSet::new(), + preserve_partitioning, + } + } + + /// Input schema + pub fn input(&self) -> &Arc { + &self.input + } + + /// Sort expressions + pub fn expr(&self) -> &[PhysicalSortExpr] { + &self.expr + } +} + +#[async_trait] +impl ExecutionPlan for ExternalSortExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.input.schema() + } + + /// Get the output partitioning of this plan + fn output_partitioning(&self) -> Partitioning { + if self.preserve_partitioning { + self.input.output_partitioning() + } else { + Partitioning::UnknownPartitioning(1) + } + } + + fn required_child_distribution(&self) -> Distribution { + if self.preserve_partitioning { + Distribution::UnspecifiedDistribution + } else { + Distribution::SinglePartition + } + } + + fn children(&self) -> Vec> { + vec![self.input.clone()] + } + + fn with_new_children( + &self, + children: Vec>, + ) -> Result> { + match children.len() { + 1 => Ok(Arc::new(ExternalSortExec::try_new( + self.expr.clone(), + children[0].clone(), + )?)), + _ => Err(DataFusionError::Internal( + "SortExec wrong number of children".to_string(), + )), + } + } + + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { + if !self.preserve_partitioning { + if 0 != partition { + return Err(DataFusionError::Internal(format!( + "SortExec invalid partition {}", + partition + ))); + } + + // sort needs to operate on a single partition currently + if 1 != self.input.output_partitioning().partition_count() { + return Err(DataFusionError::Internal( + "SortExec requires a single input partition".to_owned(), + )); + } + } + + let _baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + let input = self.input.execute(partition, runtime.clone()).await?; + + external_sort(input, partition, self.expr.clone(), runtime).await + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default => { + let expr: Vec = self.expr.iter().map(|e| e.to_string()).collect(); + write!(f, "SortExec: [{}]", expr.join(",")) + } + } + } + + fn statistics(&self) -> Statistics { + self.input.statistics() + } +} + +async fn external_sort( + mut input: SendableRecordBatchStream, + partition_id: usize, + expr: Vec, + runtime: Arc, +) -> Result { + let schema = input.schema(); + let sorter = Arc::new(ExternalSorter::new( + partition_id, + schema.clone(), + expr, + runtime.clone(), + )); + runtime.register_consumer(&(sorter.clone() as Arc)); + + while let Some(batch) = input.next().await { + let batch = batch?; + sorter.insert_batch(batch).await?; + } + + let result = sorter.sort().await; + runtime.drop_consumer(sorter.id()); + result +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::datasource::object_store::local::LocalFileSystem; + use crate::execution::runtime_env::RuntimeConfig; + use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; + use crate::physical_plan::expressions::col; + use crate::physical_plan::{ + collect, + file_format::{CsvExec, PhysicalPlanConfig}, + }; + use crate::test; + use crate::test_util; + use arrow::array::*; + use arrow::compute::SortOptions; + use arrow::datatypes::*; + + async fn sort_with_runtime(runtime: Arc) -> Result> { + let schema = test_util::aggr_test_schema(); + let partitions = 4; + let (_, files) = + test::create_partitioned_csv("aggregate_test_100.csv", partitions)?; + + let csv = CsvExec::new( + PhysicalPlanConfig { + object_store: Arc::new(LocalFileSystem {}), + file_schema: Arc::clone(&schema), + file_groups: files, + statistics: Statistics::default(), + projection: None, + batch_size: 1024, + limit: None, + table_partition_cols: vec![], + }, + true, + b',', + ); + + let sort_exec = Arc::new(ExternalSortExec::try_new( + vec![ + // c1 string column + PhysicalSortExpr { + expr: col("c1", &schema)?, + options: SortOptions::default(), + }, + // c2 uin32 column + PhysicalSortExpr { + expr: col("c2", &schema)?, + options: SortOptions::default(), + }, + // c7 uin8 column + PhysicalSortExpr { + expr: col("c7", &schema)?, + options: SortOptions::default(), + }, + ], + Arc::new(CoalescePartitionsExec::new(Arc::new(csv))), + )?); + + collect(sort_exec, runtime).await + } + + #[tokio::test] + async fn test_in_mem_sort() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); + let result = sort_with_runtime(runtime).await?; + + assert_eq!(result.len(), 1); + + let columns = result[0].columns(); + + let c1 = as_string_array(&columns[0]); + assert_eq!(c1.value(0), "a"); + assert_eq!(c1.value(c1.len() - 1), "e"); + + let c2 = as_primitive_array::(&columns[1]); + assert_eq!(c2.value(0), 1); + assert_eq!(c2.value(c2.len() - 1), 5,); + + let c7 = as_primitive_array::(&columns[6]); + assert_eq!(c7.value(0), 15); + assert_eq!(c7.value(c7.len() - 1), 254,); + + Ok(()) + } + + #[tokio::test] + async fn test_sort_spill() -> Result<()> { + let config = RuntimeConfig::new() + .with_memory_fraction(1.0) + // trigger spill there will be 4 batches with 5.5KB for each + .with_max_execution_memory(12288); + let runtime = Arc::new(RuntimeEnv::new(config)?); + let result = sort_with_runtime(runtime).await?; + + assert_eq!(result.len(), 1); + + let columns = result[0].columns(); + + let c1 = as_string_array(&columns[0]); + assert_eq!(c1.value(0), "a"); + assert_eq!(c1.value(c1.len() - 1), "e"); + + let c2 = as_primitive_array::(&columns[1]); + assert_eq!(c2.value(0), 1); + assert_eq!(c2.value(c2.len() - 1), 5,); + + let c7 = as_primitive_array::(&columns[6]); + assert_eq!(c7.value(0), 15); + assert_eq!(c7.value(c7.len() - 1), 254,); + + Ok(()) + } + + #[tokio::test] + async fn test_multi_output_batch() -> Result<()> { + let config = RuntimeConfig::new().with_batch_size(26); + let runtime = Arc::new(RuntimeEnv::new(config)?); + let result = sort_with_runtime(runtime).await?; + + assert_eq!(result.len(), 4); + + let columns_b1 = result[0].columns(); + let columns_b3 = result[3].columns(); + + let c1 = as_string_array(&columns_b1[0]); + let c13 = as_string_array(&columns_b3[0]); + assert_eq!(c1.value(0), "a"); + assert_eq!(c13.value(c13.len() - 1), "e"); + + let c2 = as_primitive_array::(&columns_b1[1]); + let c23 = as_primitive_array::(&columns_b3[1]); + assert_eq!(c2.value(0), 1); + assert_eq!(c23.value(c23.len() - 1), 5,); + + let c7 = as_primitive_array::(&columns_b1[6]); + let c73 = as_primitive_array::(&columns_b3[6]); + assert_eq!(c7.value(0), 15); + assert_eq!(c73.value(c73.len() - 1), 254,); + + Ok(()) + } +} diff --git a/datafusion/src/physical_plan/sorts/in_mem_sort.rs b/datafusion/src/physical_plan/sorts/in_mem_sort.rs new file mode 100644 index 000000000000..9e7753d42472 --- /dev/null +++ b/datafusion/src/physical_plan/sorts/in_mem_sort.rs @@ -0,0 +1,241 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::BinaryHeap; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::{ + array::{make_array as make_arrow_array, MutableArrayData}, + compute::SortOptions, + datatypes::SchemaRef, + error::{ArrowError, Result as ArrowResult}, + record_batch::RecordBatch, +}; +use futures::Stream; + +use crate::error::Result; +use crate::physical_plan::metrics::BaselineMetrics; +use crate::physical_plan::sorts::{RowIndex, SortKeyCursor}; +use crate::physical_plan::{ + expressions::PhysicalSortExpr, PhysicalExpr, RecordBatchStream, +}; + +/// Merge buffered, self-sorted record batches to get an order. +/// +/// Internally, it uses MinHeap to reduce extra memory consumption +/// by not concatenating all batches into one and sorting it as done by `SortExec`. +pub(crate) struct InMemSortStream { + /// The schema of the RecordBatches yielded by this stream + schema: SchemaRef, + /// Self sorted batches to be merged together + batches: Vec>, + /// The accumulated row indexes for the next record batch + in_progress: Vec, + /// The desired RecordBatch size to yield + target_batch_size: usize, + /// used to record execution metrics + baseline_metrics: BaselineMetrics, + /// If the stream has encountered an error + aborted: bool, + /// min heap for record comparison + min_heap: BinaryHeap, +} + +impl InMemSortStream { + pub(crate) fn new( + sorted_batches: Vec, + schema: SchemaRef, + expressions: &[PhysicalSortExpr], + target_batch_size: usize, + baseline_metrics: BaselineMetrics, + ) -> Result { + let len = sorted_batches.len(); + let mut cursors = Vec::with_capacity(len); + let mut min_heap = BinaryHeap::with_capacity(len); + + let column_expressions: Vec> = + expressions.iter().map(|x| x.expr.clone()).collect(); + + // The sort options for each expression + let sort_options: Arc> = + Arc::new(expressions.iter().map(|x| x.options).collect()); + + sorted_batches + .into_iter() + .enumerate() + .try_for_each(|(idx, batch)| { + let batch = Arc::new(batch); + let cursor = match SortKeyCursor::new( + idx, + batch.clone(), + &column_expressions, + sort_options.clone(), + ) { + Ok(cursor) => cursor, + Err(e) => return Err(e), + }; + min_heap.push(cursor); + cursors.insert(idx, batch); + Ok(()) + })?; + + Ok(Self { + schema, + batches: cursors, + target_batch_size, + baseline_metrics, + aborted: false, + in_progress: vec![], + min_heap, + }) + } + + /// Returns the index of the next batch to pull a row from, or None + /// if all cursors for all batch are exhausted + fn next_cursor(&mut self) -> Result> { + match self.min_heap.pop() { + None => Ok(None), + Some(cursor) => Ok(Some(cursor)), + } + } + + /// Drains the in_progress row indexes, and builds a new RecordBatch from them + /// + /// Will then drop any cursors for which all rows have been yielded to the output + fn build_record_batch(&mut self) -> ArrowResult { + let columns = self + .schema + .fields() + .iter() + .enumerate() + .map(|(column_idx, field)| { + let arrays = self + .batches + .iter() + .map(|batch| batch.column(column_idx).data()) + .collect(); + + let mut array_data = MutableArrayData::new( + arrays, + field.is_nullable(), + self.in_progress.len(), + ); + + if self.in_progress.is_empty() { + return make_arrow_array(array_data.freeze()); + } + + let first = &self.in_progress[0]; + let mut buffer_idx = first.stream_idx; + let mut start_row_idx = first.row_idx; + let mut end_row_idx = start_row_idx + 1; + + for row_index in self.in_progress.iter().skip(1) { + let next_buffer_idx = row_index.stream_idx; + + if next_buffer_idx == buffer_idx && row_index.row_idx == end_row_idx { + // subsequent row in same batch + end_row_idx += 1; + continue; + } + + // emit current batch of rows for current buffer + array_data.extend(buffer_idx, start_row_idx, end_row_idx); + + // start new batch of rows + buffer_idx = next_buffer_idx; + start_row_idx = row_index.row_idx; + end_row_idx = start_row_idx + 1; + } + + // emit final batch of rows + array_data.extend(buffer_idx, start_row_idx, end_row_idx); + make_arrow_array(array_data.freeze()) + }) + .collect(); + + self.in_progress.clear(); + RecordBatch::try_new(self.schema.clone(), columns) + } + + #[inline] + fn poll_next_inner( + self: &mut Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + if self.aborted { + return Poll::Ready(None); + } + + loop { + // NB timer records time taken on drop, so there are no + // calls to `timer.done()` below. + let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + let _timer = elapsed_compute.timer(); + + match self.next_cursor() { + Ok(Some(mut cursor)) => { + let batch_idx = cursor.batch_idx; + let row_idx = cursor.advance(); + + // insert the cursor back to min_heap if the record batch is not exhausted + if !cursor.is_finished() { + self.min_heap.push(cursor); + } + + self.in_progress.push(RowIndex { + stream_idx: batch_idx, + cursor_idx: 0, + row_idx, + }); + } + Ok(None) if self.in_progress.is_empty() => return Poll::Ready(None), + Ok(None) => return Poll::Ready(Some(self.build_record_batch())), + Err(e) => { + self.aborted = true; + return Poll::Ready(Some(Err(ArrowError::ExternalError(Box::new( + e, + ))))); + } + }; + + if self.in_progress.len() == self.target_batch_size { + return Poll::Ready(Some(self.build_record_batch())); + } + } + } +} + +impl Stream for InMemSortStream { + type Item = ArrowResult; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let poll = self.poll_next_inner(cx); + self.baseline_metrics.record_poll(poll) + } +} + +impl RecordBatchStream for InMemSortStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} diff --git a/datafusion/src/physical_plan/sorts/mod.rs b/datafusion/src/physical_plan/sorts/mod.rs new file mode 100644 index 000000000000..3dda13b1a178 --- /dev/null +++ b/datafusion/src/physical_plan/sorts/mod.rs @@ -0,0 +1,295 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Sort functionalities + +use crate::error; +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::{PhysicalExpr, SendableRecordBatchStream}; +use arrow::array::{ArrayRef, DynComparator}; +use arrow::compute::SortOptions; +use arrow::error::Result as ArrowResult; +use arrow::record_batch::RecordBatch; +use futures::channel::mpsc; +use futures::stream::FusedStream; +use futures::Stream; +use hashbrown::HashMap; +use std::borrow::BorrowMut; +use std::cmp::Ordering; +use std::fmt::{Debug, Formatter}; +use std::pin::Pin; +use std::sync::{Arc, RwLock}; +use std::task::{Context, Poll}; + +pub mod external_sort; +mod in_mem_sort; +pub mod sort; +pub mod sort_preserving_merge; + +/// A `SortKeyCursor` is created from a `RecordBatch`, and a set of +/// `PhysicalExpr` that when evaluated on the `RecordBatch` yield the sort keys. +/// +/// Additionally it maintains a row cursor that can be advanced through the rows +/// of the provided `RecordBatch` +/// +/// `SortKeyCursor::compare` can then be used to compare the sort key pointed to +/// by this row cursor, with that of another `SortKeyCursor`. A cursor stores +/// a row comparator for each other cursor that it is compared to. +struct SortKeyCursor { + columns: Vec, + cur_row: usize, + num_rows: usize, + + // An index uniquely identifying the record batch scanned by this cursor. + batch_idx: usize, + batch: Arc, + + // A collection of comparators that compare rows in this cursor's batch to + // the cursors in other batches. Other batches are uniquely identified by + // their batch_idx. + batch_comparators: RwLock>>, + sort_options: Arc>, +} + +impl<'a> std::fmt::Debug for SortKeyCursor { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("SortKeyCursor") + .field("columns", &self.columns) + .field("cur_row", &self.cur_row) + .field("num_rows", &self.num_rows) + .field("batch_idx", &self.batch_idx) + .field("batch", &self.batch) + .field("batch_comparators", &"") + .finish() + } +} + +impl SortKeyCursor { + fn new( + batch_idx: usize, + batch: Arc, + sort_key: &[Arc], + sort_options: Arc>, + ) -> error::Result { + let columns = sort_key + .iter() + .map(|expr| Ok(expr.evaluate(&batch)?.into_array(batch.num_rows()))) + .collect::>()?; + Ok(Self { + cur_row: 0, + num_rows: batch.num_rows(), + columns, + batch, + batch_idx, + batch_comparators: RwLock::new(HashMap::new()), + sort_options, + }) + } + + fn is_finished(&self) -> bool { + self.num_rows == self.cur_row + } + + fn advance(&mut self) -> usize { + assert!(!self.is_finished()); + let t = self.cur_row; + self.cur_row += 1; + t + } + + /// Compares the sort key pointed to by this instance's row cursor with that of another + fn compare(&self, other: &SortKeyCursor) -> error::Result { + if self.columns.len() != other.columns.len() { + return Err(DataFusionError::Internal(format!( + "SortKeyCursors had inconsistent column counts: {} vs {}", + self.columns.len(), + other.columns.len() + ))); + } + + if self.columns.len() != self.sort_options.len() { + return Err(DataFusionError::Internal(format!( + "Incorrect number of SortOptions provided to SortKeyCursor::compare, expected {} got {}", + self.columns.len(), + self.sort_options.len() + ))); + } + + let zipped: Vec<((&ArrayRef, &ArrayRef), &SortOptions)> = self + .columns + .iter() + .zip(other.columns.iter()) + .zip(self.sort_options.iter()) + .collect::>(); + + self.init_cmp_if_needed(other, &zipped)?; + let map = self.batch_comparators.read().unwrap(); + let cmp = map.get(&other.batch_idx).ok_or_else(|| { + DataFusionError::Execution(format!( + "Failed to find comparator for {} cmp {}", + self.batch_idx, other.batch_idx + )) + })?; + + for (i, ((l, r), sort_options)) in zipped.iter().enumerate() { + match (l.is_valid(self.cur_row), r.is_valid(other.cur_row)) { + (false, true) if sort_options.nulls_first => return Ok(Ordering::Less), + (false, true) => return Ok(Ordering::Greater), + (true, false) if sort_options.nulls_first => { + return Ok(Ordering::Greater) + } + (true, false) => return Ok(Ordering::Less), + (false, false) => {} + (true, true) => match cmp[i](self.cur_row, other.cur_row) { + Ordering::Equal => {} + o if sort_options.descending => return Ok(o.reverse()), + o => return Ok(o), + }, + } + } + + Ok(Ordering::Equal) + } + + /// Initialize a collection of comparators for comparing + /// columnar arrays of this cursor and "other" if needed. + fn init_cmp_if_needed( + &self, + other: &SortKeyCursor, + zipped: &[((&ArrayRef, &ArrayRef), &SortOptions)], + ) -> Result<()> { + let hm = self.batch_comparators.read().unwrap(); + if !hm.contains_key(&other.batch_idx) { + drop(hm); + let mut map = self.batch_comparators.write().unwrap(); + let cmp = map + .borrow_mut() + .entry(other.batch_idx) + .or_insert_with(|| Vec::with_capacity(other.columns.len())); + + for (i, ((l, r), _)) in zipped.iter().enumerate() { + if i >= cmp.len() { + // initialise comparators + cmp.push(arrow::array::build_compare(l.as_ref(), r.as_ref())?); + } + } + } + Ok(()) + } +} + +impl Ord for SortKeyCursor { + /// Needed by min-heap comparison in `in_mem_sort` and reverse the order at the same time. + fn cmp(&self, other: &Self) -> Ordering { + other.compare(self).unwrap() + } +} + +impl PartialEq for SortKeyCursor { + fn eq(&self, other: &Self) -> bool { + other.compare(self).unwrap() == Ordering::Equal + } +} + +impl Eq for SortKeyCursor {} + +impl PartialOrd for SortKeyCursor { + fn partial_cmp(&self, other: &Self) -> Option { + other.compare(self).ok() + } +} + +/// A `RowIndex` identifies a specific row from those buffered +/// by a `SortPreservingMergeStream` +#[derive(Debug, Clone)] +struct RowIndex { + /// The index of the stream + stream_idx: usize, + /// For sort_preserving_merge, it's the index of the cursor within the stream's VecDequeue. + /// For in_mem_sort which have only one batch for each stream, cursor_idx always 0 + cursor_idx: usize, + /// The row index + row_idx: usize, +} + +pub(crate) struct SortedStream { + stream: SendableRecordBatchStream, + mem_used: usize, +} + +impl Debug for SortedStream { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "InMemSorterStream") + } +} + +impl SortedStream { + pub(crate) fn new(stream: SendableRecordBatchStream, mem_used: usize) -> Self { + Self { stream, mem_used } + } +} + +#[derive(Debug)] +enum StreamWrapper { + Receiver(mpsc::Receiver>), + Stream(Option), +} + +impl StreamWrapper { + fn mem_used(&self) -> usize { + if let StreamWrapper::Stream(Some(s)) = &self { + s.mem_used + } else { + 0 + } + } +} + +impl Stream for StreamWrapper { + type Item = ArrowResult; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + StreamWrapper::Receiver(ref mut receiver) => Pin::new(receiver).poll_next(cx), + StreamWrapper::Stream(ref mut stream) => { + let inner = match stream { + None => return Poll::Ready(None), + Some(inner) => inner, + }; + + match Pin::new(&mut inner.stream).poll_next(cx) { + Poll::Ready(msg) => { + if msg.is_none() { + *stream = None + } + Poll::Ready(msg) + } + Poll::Pending => Poll::Pending, + } + } + } + } +} + +impl FusedStream for StreamWrapper { + fn is_terminated(&self) -> bool { + match self { + StreamWrapper::Receiver(receiver) => receiver.is_terminated(), + StreamWrapper::Stream(stream) => stream.is_none(), + } + } +} diff --git a/datafusion/src/physical_plan/sort.rs b/datafusion/src/physical_plan/sorts/sort.rs similarity index 95% rename from datafusion/src/physical_plan/sort.rs rename to datafusion/src/physical_plan/sorts/sort.rs index dec9a9136a5d..678ad03745d9 100644 --- a/datafusion/src/physical_plan/sort.rs +++ b/datafusion/src/physical_plan/sorts/sort.rs @@ -17,16 +17,17 @@ //! Defines the SORT plan -use super::common::AbortOnDropSingle; -use super::metrics::{ - BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput, -}; -use super::{RecordBatchStream, SendableRecordBatchStream, Statistics}; use crate::error::{DataFusionError, Result}; +use crate::execution::runtime_env::RuntimeEnv; +use crate::physical_plan::common::AbortOnDropSingle; use crate::physical_plan::expressions::PhysicalSortExpr; +use crate::physical_plan::metrics::{ + BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput, +}; use crate::physical_plan::{ common, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, }; +use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream, Statistics}; pub use arrow::compute::SortOptions; use arrow::compute::{lexsort_to_indices, take, SortColumn, TakeOptions}; use arrow::datatypes::SchemaRef; @@ -137,7 +138,11 @@ impl ExecutionPlan for SortExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { if !self.preserve_partitioning { if 0 != partition { return Err(DataFusionError::Internal(format!( @@ -155,7 +160,7 @@ impl ExecutionPlan for SortExec { } let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); - let input = self.input.execute(partition).await?; + let input = self.input.execute(partition, runtime).await?; Ok(Box::pin(SortStream::new( input, @@ -186,7 +191,7 @@ impl ExecutionPlan for SortExec { } } -fn sort_batch( +pub(crate) fn sort_batch( batch: RecordBatch, schema: SchemaRef, expr: &[PhysicalSortExpr], @@ -330,6 +335,7 @@ mod tests { #[tokio::test] async fn test_sort() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let partitions = 4; let (_, files) = @@ -371,7 +377,7 @@ mod tests { Arc::new(CoalescePartitionsExec::new(Arc::new(csv))), )?); - let result: Vec = collect(sort_exec).await?; + let result: Vec = collect(sort_exec, runtime).await?; assert_eq!(result.len(), 1); let columns = result[0].columns(); @@ -393,6 +399,7 @@ mod tests { #[tokio::test] async fn test_sort_metadata() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let field_metadata: BTreeMap = vec![("foo".to_string(), "bar".to_string())] .into_iter() @@ -422,7 +429,7 @@ mod tests { input, )?); - let result: Vec = collect(sort_exec).await?; + let result: Vec = collect(sort_exec, runtime).await?; let expected_data: ArrayRef = Arc::new(vec![1, 2, 3].into_iter().map(Some).collect::()); @@ -444,6 +451,7 @@ mod tests { #[tokio::test] async fn test_lex_sort_by_float() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Float32, true), Field::new("b", DataType::Float64, true), @@ -499,7 +507,7 @@ mod tests { assert_eq!(DataType::Float32, *sort_exec.schema().field(0).data_type()); assert_eq!(DataType::Float64, *sort_exec.schema().field(1).data_type()); - let result: Vec = collect(sort_exec.clone()).await?; + let result: Vec = collect(sort_exec.clone(), runtime).await?; let metrics = sort_exec.metrics().unwrap(); assert!(metrics.elapsed_compute().unwrap() > 0); assert_eq!(metrics.output_rows().unwrap(), 8); @@ -548,6 +556,7 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -561,7 +570,7 @@ mod tests { blocking_exec, )?); - let fut = collect(sort_exec); + let fut = collect(sort_exec, runtime); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/src/physical_plan/sort_preserving_merge.rs b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs similarity index 83% rename from datafusion/src/physical_plan/sort_preserving_merge.rs rename to datafusion/src/physical_plan/sorts/sort_preserving_merge.rs index 632658058e3b..fa49daf5a1a6 100644 --- a/datafusion/src/physical_plan/sort_preserving_merge.rs +++ b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs @@ -17,18 +17,20 @@ //! Defines the sort preserving merge plan -use super::common::AbortOnDropMany; -use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; +use crate::physical_plan::common::AbortOnDropMany; +use crate::physical_plan::metrics::{ + BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, +}; use std::any::Any; use std::cmp::Ordering; use std::collections::VecDeque; +use std::fmt::{Debug, Formatter}; use std::pin::Pin; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::task::{Context, Poll}; -use arrow::array::DynComparator; use arrow::{ - array::{make_array as make_arrow_array, ArrayRef, MutableArrayData}, + array::{make_array as make_arrow_array, MutableArrayData}, compute::SortOptions, datatypes::SchemaRef, error::{ArrowError, Result as ArrowResult}, @@ -38,9 +40,13 @@ use async_trait::async_trait; use futures::channel::mpsc; use futures::stream::FusedStream; use futures::{Stream, StreamExt}; -use hashbrown::HashMap; use crate::error::{DataFusionError, Result}; +use crate::execution::memory_manager::{ + ConsumerType, MemoryConsumer, MemoryConsumerId, MemoryManager, +}; +use crate::execution::runtime_env::RuntimeEnv; +use crate::physical_plan::sorts::{RowIndex, SortKeyCursor, SortedStream, StreamWrapper}; use crate::physical_plan::{ common::spawn_execution, expressions::PhysicalSortExpr, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream, @@ -131,7 +137,11 @@ impl ExecutionPlan for SortPreservingMergeExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { if 0 != partition { return Err(DataFusionError::Internal(format!( "SortPreservingMergeExec invalid partition {}", @@ -149,27 +159,36 @@ impl ExecutionPlan for SortPreservingMergeExec { )), 1 => { // bypass if there is only one partition to merge (no metrics in this case either) - self.input.execute(0).await + self.input.execute(0, runtime).await } _ => { let (receivers, join_handles) = (0..input_partitions) .into_iter() .map(|part_i| { let (sender, receiver) = mpsc::channel(1); - let join_handle = - spawn_execution(self.input.clone(), sender, part_i); + let join_handle = spawn_execution( + self.input.clone(), + sender, + part_i, + runtime.clone(), + ); (receiver, join_handle) }) .unzip(); - Ok(Box::pin(SortPreservingMergeStream::new( - receivers, - AbortOnDropMany(join_handles), - self.schema(), - &self.expr, - self.target_batch_size, - baseline_metrics, - ))) + Ok(Box::pin( + SortPreservingMergeStream::new_from_receiver( + receivers, + AbortOnDropMany(join_handles), + self.schema(), + &self.expr, + self.target_batch_size, + baseline_metrics, + partition, + runtime.clone(), + ) + .await, + )) } } } @@ -196,154 +215,76 @@ impl ExecutionPlan for SortPreservingMergeExec { } } -/// A `SortKeyCursor` is created from a `RecordBatch`, and a set of -/// `PhysicalExpr` that when evaluated on the `RecordBatch` yield the sort keys. -/// -/// Additionally it maintains a row cursor that can be advanced through the rows -/// of the provided `RecordBatch` -/// -/// `SortKeyCursor::compare` can then be used to compare the sort key pointed to -/// by this row cursor, with that of another `SortKeyCursor`. A cursor stores -/// a row comparator for each other cursor that it is compared to. -struct SortKeyCursor { - columns: Vec, - cur_row: usize, - num_rows: usize, - - // An index uniquely identifying the record batch scanned by this cursor. - batch_idx: usize, - batch: RecordBatch, - - // A collection of comparators that compare rows in this cursor's batch to - // the cursors in other batches. Other batches are uniquely identified by - // their batch_idx. - batch_comparators: HashMap>, +struct MergingStreams { + /// ConsumerId + id: MemoryConsumerId, + /// The sorted input streams to merge together + pub(crate) streams: Mutex>, + /// Runtime + runtime: Arc, } -impl<'a> std::fmt::Debug for SortKeyCursor { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SortKeyCursor") - .field("columns", &self.columns) - .field("cur_row", &self.cur_row) - .field("num_rows", &self.num_rows) - .field("batch_idx", &self.batch_idx) - .field("batch", &self.batch) - .field("batch_comparators", &"") +impl Debug for MergingStreams { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + f.debug_struct("MergingStreams") + .field("id", &self.id()) .finish() } } -impl SortKeyCursor { - fn new( - batch_idx: usize, - batch: RecordBatch, - sort_key: &[Arc], - ) -> Result { - let columns = sort_key - .iter() - .map(|expr| Ok(expr.evaluate(&batch)?.into_array(batch.num_rows()))) - .collect::>()?; - Ok(Self { - cur_row: 0, - num_rows: batch.num_rows(), - columns, - batch, - batch_idx, - batch_comparators: HashMap::new(), - }) +impl MergingStreams { + pub fn new( + partition: usize, + input_streams: Vec, + runtime: Arc, + ) -> Self { + Self { + id: MemoryConsumerId::new(partition), + streams: Mutex::new(input_streams), + runtime, + } } +} - fn is_finished(&self) -> bool { - self.num_rows == self.cur_row +#[async_trait] +impl MemoryConsumer for MergingStreams { + fn name(&self) -> String { + "MergingStreams".to_owned() } - fn advance(&mut self) -> usize { - assert!(!self.is_finished()); - let t = self.cur_row; - self.cur_row += 1; - t + fn id(&self) -> &MemoryConsumerId { + &self.id } - /// Compares the sort key pointed to by this instance's row cursor with that of another - fn compare( - &mut self, - other: &SortKeyCursor, - options: &[SortOptions], - ) -> Result { - if self.columns.len() != other.columns.len() { - return Err(DataFusionError::Internal(format!( - "SortKeyCursors had inconsistent column counts: {} vs {}", - self.columns.len(), - other.columns.len() - ))); - } - - if self.columns.len() != options.len() { - return Err(DataFusionError::Internal(format!( - "Incorrect number of SortOptions provided to SortKeyCursor::compare, expected {} got {}", - self.columns.len(), - options.len() - ))); - } - - let zipped = self - .columns - .iter() - .zip(other.columns.iter()) - .zip(options.iter()); - - // Recall or initialise a collection of comparators for comparing - // columnar arrays of this cursor and "other". - let cmp = self - .batch_comparators - .entry(other.batch_idx) - .or_insert_with(|| Vec::with_capacity(other.columns.len())); - - for (i, ((l, r), sort_options)) in zipped.enumerate() { - if i >= cmp.len() { - // initialise comparators as potentially needed - cmp.push(arrow::array::build_compare(l.as_ref(), r.as_ref())?); - } + fn memory_manager(&self) -> Arc { + self.runtime.memory_manager.clone() + } - match (l.is_valid(self.cur_row), r.is_valid(other.cur_row)) { - (false, true) if sort_options.nulls_first => return Ok(Ordering::Less), - (false, true) => return Ok(Ordering::Greater), - (true, false) if sort_options.nulls_first => { - return Ok(Ordering::Greater) - } - (true, false) => return Ok(Ordering::Less), - (false, false) => {} - (true, true) => match cmp[i](self.cur_row, other.cur_row) { - Ordering::Equal => {} - o if sort_options.descending => return Ok(o.reverse()), - o => return Ok(o), - }, - } - } + fn type_(&self) -> &ConsumerType { + &ConsumerType::Tracking + } - Ok(Ordering::Equal) + async fn spill(&self) -> Result { + return Err(DataFusionError::Internal(format!( + "Calling spill on a tracking only consumer {}, {}", + self.name(), + self.id, + ))); } -} -/// A `RowIndex` identifies a specific row from those buffered -/// by a `SortPreservingMergeStream` -#[derive(Debug, Clone)] -struct RowIndex { - /// The index of the stream - stream_idx: usize, - /// The index of the cursor within the stream's VecDequeue - cursor_idx: usize, - /// The row index - row_idx: usize, + fn mem_used(&self) -> usize { + let streams = self.streams.lock().unwrap(); + streams.iter().map(StreamWrapper::mem_used).sum::() + } } #[derive(Debug)] -struct SortPreservingMergeStream { +pub(crate) struct SortPreservingMergeStream { /// The schema of the RecordBatches yielded by this stream schema: SchemaRef, /// The sorted input streams to merge together - receivers: Vec>>, + streams: Arc, /// Drop helper for tasks feeding the [`receivers`](Self::receivers) _drop_helper: AbortOnDropMany<()>, @@ -361,7 +302,7 @@ struct SortPreservingMergeStream { column_expressions: Vec>, /// The sort options for each expression - sort_options: Vec, + sort_options: Arc>, /// The desired RecordBatch size to yield target_batch_size: usize, @@ -374,34 +315,89 @@ struct SortPreservingMergeStream { /// An index to uniquely identify the input stream batch next_batch_index: usize, + + /// runtime + runtime: Arc, +} + +impl Drop for SortPreservingMergeStream { + fn drop(&mut self) { + self.runtime.drop_consumer(self.streams.id()) + } } impl SortPreservingMergeStream { - fn new( + #[allow(clippy::too_many_arguments)] + pub(crate) async fn new_from_receiver( receivers: Vec>>, _drop_helper: AbortOnDropMany<()>, schema: SchemaRef, expressions: &[PhysicalSortExpr], target_batch_size: usize, baseline_metrics: BaselineMetrics, + partition: usize, + runtime: Arc, ) -> Self { let cursors = (0..receivers.len()) .into_iter() .map(|_| VecDeque::new()) .collect(); + let wrappers = receivers.into_iter().map(StreamWrapper::Receiver).collect(); + let streams = Arc::new(MergingStreams::new(partition, wrappers, runtime.clone())); + runtime.register_consumer(&(streams.clone() as Arc)); + Self { schema, cursors, - receivers, + streams, _drop_helper, column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(), - sort_options: expressions.iter().map(|x| x.options).collect(), + sort_options: Arc::new(expressions.iter().map(|x| x.options).collect()), target_batch_size, baseline_metrics, aborted: false, in_progress: vec![], next_batch_index: 0, + runtime, + } + } + + pub(crate) async fn new_from_stream( + streams: Vec, + schema: SchemaRef, + expressions: &[PhysicalSortExpr], + target_batch_size: usize, + baseline_metrics: BaselineMetrics, + partition: usize, + runtime: Arc, + ) -> Self { + let cursors = (0..streams.len()) + .into_iter() + .map(|_| VecDeque::new()) + .collect(); + + let wrappers = streams + .into_iter() + .map(|s| StreamWrapper::Stream(Some(s))) + .collect::>(); + + let streams = Arc::new(MergingStreams::new(partition, wrappers, runtime.clone())); + runtime.register_consumer(&(streams.clone() as Arc)); + + Self { + schema, + cursors, + streams, + _drop_helper: AbortOnDropMany(vec![]), + column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(), + sort_options: Arc::new(expressions.iter().map(|x| x.options).collect()), + target_batch_size, + baseline_metrics, + aborted: false, + in_progress: vec![], + next_batch_index: 0, + runtime, } } @@ -420,7 +416,9 @@ impl SortPreservingMergeStream { } } - let stream = &mut self.receivers[idx]; + let mut streams = self.streams.streams.lock().unwrap(); + + let stream = &mut streams[idx]; if stream.is_terminated() { return Poll::Ready(Ok(())); } @@ -434,8 +432,9 @@ impl SortPreservingMergeStream { Some(Ok(batch)) => { let cursor = match SortKeyCursor::new( self.next_batch_index, // assign this batch an ID - batch, + Arc::new(batch), &self.column_expressions, + self.sort_options.clone(), ) { Ok(cursor) => cursor, Err(e) => { @@ -463,9 +462,7 @@ impl SortPreservingMergeStream { match min_cursor { None => min_cursor = Some((idx, candidate)), Some((_, ref mut min)) => { - if min.compare(candidate, &self.sort_options)? - == Ordering::Greater - { + if min.compare(candidate)? == Ordering::Greater { min_cursor = Some((idx, candidate)) } } @@ -661,6 +658,7 @@ mod tests { use crate::datasource::object_store::local::LocalFileSystem; use crate::physical_plan::metrics::MetricValue; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; + use arrow::array::ArrayRef; use std::iter::FromIterator; use crate::arrow::array::{Int32Array, StringArray, TimestampNanosecondArray}; @@ -668,7 +666,7 @@ mod tests { use crate::physical_plan::expressions::col; use crate::physical_plan::file_format::{CsvExec, PhysicalPlanConfig}; use crate::physical_plan::memory::MemoryExec; - use crate::physical_plan::sort::SortExec; + use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::{collect, common}; use crate::test::{self, assert_is_pending}; use crate::{assert_batches_eq, test_util}; @@ -680,6 +678,7 @@ mod tests { #[tokio::test] async fn test_merge_interleave() { + let runtime = Arc::new(RuntimeEnv::default()); let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ Some("a"), @@ -720,12 +719,14 @@ mod tests { "| 3 | j | 1970-01-01 00:00:00.000000008 |", "+----+---+-------------------------------+", ], + runtime, ) .await; } #[tokio::test] async fn test_merge_some_overlap() { + let runtime = Arc::new(RuntimeEnv::default()); let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ Some("a"), @@ -766,12 +767,14 @@ mod tests { "| 110 | g | 1970-01-01 00:00:00.000000006 |", "+-----+---+-------------------------------+", ], + runtime, ) .await; } #[tokio::test] async fn test_merge_no_overlap() { + let runtime = Arc::new(RuntimeEnv::default()); let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ Some("a"), @@ -812,12 +815,14 @@ mod tests { "| 30 | j | 1970-01-01 00:00:00.000000006 |", "+----+---+-------------------------------+", ], + runtime, ) .await; } #[tokio::test] async fn test_merge_three_partitions() { + let runtime = Arc::new(RuntimeEnv::default()); let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ Some("a"), @@ -875,11 +880,16 @@ mod tests { "| 30 | j | 1970-01-01 00:00:00.000000060 |", "+-----+---+-------------------------------+", ], + runtime, ) .await; } - async fn _test_merge(partitions: &[Vec], exp: &[&str]) { + async fn _test_merge( + partitions: &[Vec], + exp: &[&str], + runtime: Arc, + ) { let schema = partitions[0][0].schema(); let sort = vec![ PhysicalSortExpr { @@ -894,16 +904,17 @@ mod tests { let exec = MemoryExec::try_new(partitions, schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec), 1024)); - let collected = collect(merge).await.unwrap(); + let collected = collect(merge, runtime).await.unwrap(); assert_batches_eq!(exp, collected.as_slice()); } async fn sorted_merge( input: Arc, sort: Vec, + runtime: Arc, ) -> RecordBatch { let merge = Arc::new(SortPreservingMergeExec::new(sort, input, 1024)); - let mut result = collect(merge).await.unwrap(); + let mut result = collect(merge, runtime).await.unwrap(); assert_eq!(result.len(), 1); result.remove(0) } @@ -911,25 +922,28 @@ mod tests { async fn partition_sort( input: Arc, sort: Vec, + runtime: Arc, ) -> RecordBatch { let sort_exec = Arc::new(SortExec::new_with_partitioning(sort.clone(), input, true)); - sorted_merge(sort_exec, sort).await + sorted_merge(sort_exec, sort, runtime).await } async fn basic_sort( src: Arc, sort: Vec, + runtime: Arc, ) -> RecordBatch { let merge = Arc::new(CoalescePartitionsExec::new(src)); let sort_exec = Arc::new(SortExec::try_new(sort, merge).unwrap()); - let mut result = collect(sort_exec).await.unwrap(); + let mut result = collect(sort_exec, runtime).await.unwrap(); assert_eq!(result.len(), 1); result.remove(0) } #[tokio::test] async fn test_partition_sort() { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let partitions = 4; let (_, files) = @@ -972,8 +986,8 @@ mod tests { }, ]; - let basic = basic_sort(csv.clone(), sort.clone()).await; - let partition = partition_sort(csv, sort).await; + let basic = basic_sort(csv.clone(), sort.clone(), runtime.clone()).await; + let partition = partition_sort(csv, sort, runtime.clone()).await; let basic = arrow::util::pretty::pretty_format_batches(&[basic]) .unwrap() @@ -1016,6 +1030,7 @@ mod tests { async fn sorted_partitioned_input( sort: Vec, sizes: &[usize], + runtime: Arc, ) -> Arc { let schema = test_util::aggr_test_schema(); let partitions = 4; @@ -1037,7 +1052,7 @@ mod tests { b',', )); - let sorted = basic_sort(csv, sort).await; + let sorted = basic_sort(csv, sort, runtime).await; let split: Vec<_> = sizes.iter().map(|x| split_batch(&sorted, *x)).collect(); Arc::new(MemoryExec::try_new(&split, sorted.schema(), None).unwrap()) @@ -1045,6 +1060,7 @@ mod tests { #[tokio::test] async fn test_partition_sort_streaming_input() { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let sort = vec![ // uint8 @@ -1069,9 +1085,10 @@ mod tests { }, ]; - let input = sorted_partitioned_input(sort.clone(), &[10, 3, 11]).await; - let basic = basic_sort(input.clone(), sort.clone()).await; - let partition = sorted_merge(input, sort).await; + let input = + sorted_partitioned_input(sort.clone(), &[10, 3, 11], runtime.clone()).await; + let basic = basic_sort(input.clone(), sort.clone(), runtime.clone()).await; + let partition = sorted_merge(input, sort, runtime.clone()).await; assert_eq!(basic.num_rows(), 300); assert_eq!(partition.num_rows(), 300); @@ -1088,6 +1105,7 @@ mod tests { #[tokio::test] async fn test_partition_sort_streaming_input_output() { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let sort = vec![ @@ -1103,11 +1121,12 @@ mod tests { }, ]; - let input = sorted_partitioned_input(sort.clone(), &[10, 5, 13]).await; - let basic = basic_sort(input.clone(), sort.clone()).await; + let input = + sorted_partitioned_input(sort.clone(), &[10, 5, 13], runtime.clone()).await; + let basic = basic_sort(input.clone(), sort.clone(), runtime.clone()).await; let merge = Arc::new(SortPreservingMergeExec::new(sort, input, 23)); - let merged = collect(merge).await.unwrap(); + let merged = collect(merge, runtime.clone()).await.unwrap(); assert_eq!(merged.len(), 14); @@ -1126,6 +1145,7 @@ mod tests { #[tokio::test] async fn test_nulls() { + let runtime = Arc::new(RuntimeEnv::default()); let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ None, @@ -1180,7 +1200,7 @@ mod tests { let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec), 1024)); - let collected = collect(merge).await.unwrap(); + let collected = collect(merge, runtime).await.unwrap(); assert_eq!(collected.len(), 1); assert_batches_eq!( @@ -1206,13 +1226,15 @@ mod tests { #[tokio::test] async fn test_async() { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let sort = vec![PhysicalSortExpr { expr: col("c12", &schema).unwrap(), options: SortOptions::default(), }]; - let batches = sorted_partitioned_input(sort.clone(), &[5, 7, 3]).await; + let batches = + sorted_partitioned_input(sort.clone(), &[5, 7, 3], runtime.clone()).await; let partition_count = batches.output_partitioning().partition_count(); let mut join_handles = Vec::with_capacity(partition_count); @@ -1220,7 +1242,7 @@ mod tests { for partition in 0..partition_count { let (mut sender, receiver) = mpsc::channel(1); - let mut stream = batches.execute(partition).await.unwrap(); + let mut stream = batches.execute(partition, runtime.clone()).await.unwrap(); let join_handle = tokio::spawn(async move { while let Some(batch) = stream.next().await { sender.send(batch).await.unwrap(); @@ -1235,7 +1257,7 @@ mod tests { let metrics = ExecutionPlanMetricsSet::new(); let baseline_metrics = BaselineMetrics::new(&metrics, 0); - let merge_stream = SortPreservingMergeStream::new( + let merge_stream = SortPreservingMergeStream::new_from_receiver( receivers, // Use empty vector since we want to use the join handles ourselves AbortOnDropMany(vec![]), @@ -1243,7 +1265,10 @@ mod tests { sort.as_slice(), 1024, baseline_metrics, - ); + 0, + runtime.clone(), + ) + .await; let mut merged = common::collect(Box::pin(merge_stream)).await.unwrap(); @@ -1254,7 +1279,7 @@ mod tests { assert_eq!(merged.len(), 1); let merged = merged.remove(0); - let basic = basic_sort(batches, sort.clone()).await; + let basic = basic_sort(batches, sort.clone(), runtime.clone()).await; let basic = arrow::util::pretty::pretty_format_batches(&[basic]) .unwrap() @@ -1272,6 +1297,7 @@ mod tests { #[tokio::test] async fn test_merge_metrics() { + let runtime = Arc::new(RuntimeEnv::default()); let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"), Some("c")])); let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); @@ -1288,7 +1314,7 @@ mod tests { let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec), 1024)); - let collected = collect(merge.clone()).await.unwrap(); + let collected = collect(merge.clone(), runtime).await.unwrap(); let expected = vec![ "+----+---+", "| a | b |", @@ -1327,6 +1353,7 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -1341,7 +1368,7 @@ mod tests { 1, )); - let fut = collect(sort_preserving_merge_exec); + let fut = collect(sort_preserving_merge_exec, runtime); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/src/physical_plan/udaf.rs b/datafusion/src/physical_plan/udaf.rs index 08ea5d30946e..974b4a9df764 100644 --- a/datafusion/src/physical_plan/udaf.rs +++ b/datafusion/src/physical_plan/udaf.rs @@ -56,7 +56,7 @@ pub struct AggregateUDF { } impl Debug for AggregateUDF { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { f.debug_struct("AggregateUDF") .field("name", &self.name) .field("signature", &self.signature) diff --git a/datafusion/src/physical_plan/udf.rs b/datafusion/src/physical_plan/udf.rs index 0c5e80baea31..af0765877c1b 100644 --- a/datafusion/src/physical_plan/udf.rs +++ b/datafusion/src/physical_plan/udf.rs @@ -54,7 +54,7 @@ pub struct ScalarUDF { } impl Debug for ScalarUDF { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { f.debug_struct("ScalarUDF") .field("name", &self.name) .field("signature", &self.signature) diff --git a/datafusion/src/physical_plan/union.rs b/datafusion/src/physical_plan/union.rs index 79c50720496d..efbc62359f46 100644 --- a/datafusion/src/physical_plan/union.rs +++ b/datafusion/src/physical_plan/union.rs @@ -31,6 +31,7 @@ use super::{ ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, }; +use crate::execution::runtime_env::RuntimeEnv; use crate::{ error::Result, physical_plan::{expressions, metrics::BaselineMetrics}, @@ -91,7 +92,11 @@ impl ExecutionPlan for UnionExec { Ok(Arc::new(UnionExec::new(children))) } - async fn execute(&self, mut partition: usize) -> Result { + async fn execute( + &self, + mut partition: usize, + runtime: Arc, + ) -> Result { let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); // record the tiny amount of work done in this function so // elapsed_compute is reported as non zero @@ -102,7 +107,7 @@ impl ExecutionPlan for UnionExec { for input in self.inputs.iter() { // Calculate whether partition belongs to the current partition if partition < input.output_partitioning().partition_count() { - let stream = input.execute(partition).await?; + let stream = input.execute(partition, runtime.clone()).await?; return Ok(Box::pin(ObservedStream::new(stream, baseline_metrics))); } else { partition -= input.output_partitioning().partition_count(); @@ -232,6 +237,7 @@ mod tests { #[tokio::test] async fn test_union_partitions() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = test_util::aggr_test_schema(); let fs: Arc = Arc::new(LocalFileSystem {}); @@ -274,7 +280,7 @@ mod tests { // Should have 9 partitions and 9 output batches assert_eq!(union_exec.output_partitioning().partition_count(), 9); - let result: Vec = collect(union_exec).await?; + let result: Vec = collect(union_exec, runtime).await?; assert_eq!(result.len(), 9); Ok(()) diff --git a/datafusion/src/physical_plan/values.rs b/datafusion/src/physical_plan/values.rs index f4f8ccb6246a..c3a7ea5c162c 100644 --- a/datafusion/src/physical_plan/values.rs +++ b/datafusion/src/physical_plan/values.rs @@ -19,6 +19,7 @@ use super::{common, SendableRecordBatchStream, Statistics}; use crate::error::{DataFusionError, Result}; +use crate::execution::runtime_env::RuntimeEnv; use crate::physical_plan::{ memory::MemoryStream, ColumnarValue, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, @@ -133,7 +134,11 @@ impl ExecutionPlan for ValuesExec { } } - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + _runtime: Arc, + ) -> Result { // GlobalLimitExec has a single output partition if 0 != partition { return Err(DataFusionError::Internal(format!( diff --git a/datafusion/src/physical_plan/windows/mod.rs b/datafusion/src/physical_plan/windows/mod.rs index 497cbc3c446d..42bc27c46283 100644 --- a/datafusion/src/physical_plan/windows/mod.rs +++ b/datafusion/src/physical_plan/windows/mod.rs @@ -174,6 +174,7 @@ pub(crate) fn find_ranges_in_range<'a>( mod tests { use super::*; use crate::datasource::object_store::local::LocalFileSystem; + use crate::execution::runtime_env::RuntimeEnv; use crate::physical_plan::aggregates::AggregateFunction; use crate::physical_plan::expressions::col; use crate::physical_plan::file_format::{CsvExec, PhysicalPlanConfig}; @@ -211,6 +212,7 @@ mod tests { #[tokio::test] async fn window_function() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let (input, schema) = create_test_schema(1)?; let window_exec = Arc::new(WindowAggExec::try_new( @@ -247,7 +249,7 @@ mod tests { schema.clone(), )?); - let result: Vec = collect(window_exec).await?; + let result: Vec = collect(window_exec, runtime).await?; assert_eq!(result.len(), 1); let columns = result[0].columns(); @@ -271,6 +273,7 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { + let runtime = Arc::new(RuntimeEnv::default()); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -290,7 +293,7 @@ mod tests { schema, )?); - let fut = collect(window_agg_exec); + let fut = collect(window_agg_exec, runtime); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/src/physical_plan/windows/window_agg_exec.rs b/datafusion/src/physical_plan/windows/window_agg_exec.rs index 228b53f2be3e..b86ac1b02385 100644 --- a/datafusion/src/physical_plan/windows/window_agg_exec.rs +++ b/datafusion/src/physical_plan/windows/window_agg_exec.rs @@ -18,6 +18,7 @@ //! Stream and channel implementations for window function expressions. use crate::error::{DataFusionError, Result}; +use crate::execution::runtime_env::RuntimeEnv; use crate::physical_plan::common::AbortOnDropSingle; use crate::physical_plan::metrics::{ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, @@ -141,8 +142,12 @@ impl ExecutionPlan for WindowAggExec { } } - async fn execute(&self, partition: usize) -> Result { - let input = self.input.execute(partition).await?; + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { + let input = self.input.execute(partition, runtime).await?; let stream = Box::pin(WindowAggStream::new( self.schema.clone(), self.window_expr.clone(), diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index cf6e8a1ac1c2..6f80e9b57780 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -2108,7 +2108,7 @@ impl fmt::Display for ScalarValue { } impl fmt::Debug for ScalarValue { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { ScalarValue::Decimal128(_, _, _) => write!(f, "Decimal128({})", self), ScalarValue::Boolean(_) => write!(f, "Boolean({})", self), diff --git a/datafusion/src/test/exec.rs b/datafusion/src/test/exec.rs index 8351c9bccf97..39b8e5c11f5b 100644 --- a/datafusion/src/test/exec.rs +++ b/datafusion/src/test/exec.rs @@ -33,6 +33,7 @@ use arrow::{ }; use futures::Stream; +use crate::execution::runtime_env::RuntimeEnv; use crate::physical_plan::{ common, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, @@ -162,7 +163,11 @@ impl ExecutionPlan for MockExec { } /// Returns a stream which yields data - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + _runtime: Arc, + ) -> Result { assert_eq!(partition, 0); // Result doesn't implement clone, so do it ourself @@ -293,7 +298,11 @@ impl ExecutionPlan for BarrierExec { } /// Returns a stream which yields data - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + _runtime: Arc, + ) -> Result { assert!(partition < self.data.len()); let (tx, rx) = tokio::sync::mpsc::channel(2); @@ -386,7 +395,11 @@ impl ExecutionPlan for ErrorExec { } /// Returns a stream which yields data - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + _runtime: Arc, + ) -> Result { Err(DataFusionError::Internal(format!( "ErrorExec, unsurprisingly, errored in partition {}", partition @@ -463,7 +476,11 @@ impl ExecutionPlan for StatisticsExec { } } - async fn execute(&self, _partition: usize) -> Result { + async fn execute( + &self, + _partition: usize, + _runtime: Arc, + ) -> Result { unimplemented!("This plan only serves for testing statistics") } @@ -553,7 +570,11 @@ impl ExecutionPlan for BlockingExec { ))) } - async fn execute(&self, _partition: usize) -> Result { + async fn execute( + &self, + _partition: usize, + _runtime: Arc, + ) -> Result { Ok(Box::pin(BlockingStream { schema: Arc::clone(&self.schema), _refs: Arc::clone(&self.refs), diff --git a/datafusion/tests/custom_sources.rs b/datafusion/tests/custom_sources.rs index b1288f7b5f63..4f027e903ec0 100644 --- a/datafusion/tests/custom_sources.rs +++ b/datafusion/tests/custom_sources.rs @@ -45,6 +45,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; use async_trait::async_trait; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_plan::plan::Projection; //// Custom source dataframe tests //// @@ -132,7 +133,11 @@ impl ExecutionPlan for CustomExecutionPlan { )) } } - async fn execute(&self, _partition: usize) -> Result { + async fn execute( + &self, + _partition: usize, + _runtime: Arc, + ) -> Result { Ok(Box::pin(TestCustomRecordBatchStream { nb_batch: 1 })) } @@ -243,7 +248,8 @@ async fn custom_source_dataframe() -> Result<()> { assert_eq!(1, physical_plan.schema().fields().len()); assert_eq!("c2", physical_plan.schema().field(0).name().as_str()); - let batches = collect(physical_plan).await?; + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let batches = collect(physical_plan, runtime).await?; let origin_rec_batch = TEST_CUSTOM_RECORD_BATCH!()?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); @@ -289,7 +295,8 @@ async fn optimizers_catch_all_statistics() { ) .unwrap(); - let actual = collect(physical_plan).await.unwrap(); + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let actual = collect(physical_plan, runtime).await.unwrap(); assert_eq!(actual.len(), 1); assert_eq!(format!("{:?}", actual[0]), format!("{:?}", expected)); diff --git a/datafusion/tests/parquet_pruning.rs b/datafusion/tests/parquet_pruning.rs index ee27a33f86f2..9abf3fd55a64 100644 --- a/datafusion/tests/parquet_pruning.rs +++ b/datafusion/tests/parquet_pruning.rs @@ -537,7 +537,8 @@ impl ContextWithParquet { .await .expect("creating physical plan"); - let results = datafusion::physical_plan::collect(physical_plan.clone()) + let runtime = self.ctx.state.lock().unwrap().runtime_env.clone(); + let results = datafusion::physical_plan::collect(physical_plan.clone(), runtime) .await .expect("Running"); diff --git a/datafusion/tests/provider_filter_pushdown.rs b/datafusion/tests/provider_filter_pushdown.rs index f1655c5267b3..330e95c6b037 100644 --- a/datafusion/tests/provider_filter_pushdown.rs +++ b/datafusion/tests/provider_filter_pushdown.rs @@ -22,6 +22,7 @@ use async_trait::async_trait; use datafusion::datasource::datasource::{TableProvider, TableProviderFilterPushDown}; use datafusion::error::Result; use datafusion::execution::context::ExecutionContext; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_plan::Expr; use datafusion::physical_plan::common::SizedRecordBatchStream; use datafusion::physical_plan::{ @@ -78,7 +79,11 @@ impl ExecutionPlan for CustomPlan { unreachable!() } - async fn execute(&self, _: usize) -> Result { + async fn execute( + &self, + _partition: usize, + _runtime: Arc, + ) -> Result { Ok(Box::pin(SizedRecordBatchStream::new( self.schema(), self.batches.clone(), diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs index edf530be8b7d..c8586c6b47e9 100644 --- a/datafusion/tests/sql/aggregates.rs +++ b/datafusion/tests/sql/aggregates.rs @@ -25,7 +25,8 @@ async fn csv_query_avg_multi_batch() -> Result<()> { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let results = collect(plan).await.unwrap(); + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let results = collect(plan, runtime).await.unwrap(); let batch = &results[0]; let column = batch.column(0); let array = column.as_any().downcast_ref::().unwrap(); diff --git a/datafusion/tests/sql/avro.rs b/datafusion/tests/sql/avro.rs index 3983389dae34..f3c0f0c525be 100644 --- a/datafusion/tests/sql/avro.rs +++ b/datafusion/tests/sql/avro.rs @@ -124,7 +124,8 @@ async fn avro_single_nan_schema() { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let results = collect(plan).await.unwrap(); + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let results = collect(plan, runtime).await.unwrap(); for batch in results { assert_eq!(1, batch.num_rows()); assert_eq!(1, batch.num_columns()); diff --git a/datafusion/tests/sql/errors.rs b/datafusion/tests/sql/errors.rs index 9cd7bc96ff89..05ca0642bae0 100644 --- a/datafusion/tests/sql/errors.rs +++ b/datafusion/tests/sql/errors.rs @@ -37,7 +37,8 @@ async fn test_cast_expressions_error() -> Result<()> { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let result = collect(plan).await; + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let result = collect(plan, runtime).await; match result { Ok(_) => panic!("expected error"), diff --git a/datafusion/tests/sql/explain_analyze.rs b/datafusion/tests/sql/explain_analyze.rs index a9cef73521eb..128a0d82ab58 100644 --- a/datafusion/tests/sql/explain_analyze.rs +++ b/datafusion/tests/sql/explain_analyze.rs @@ -41,7 +41,8 @@ async fn explain_analyze_baseline_metrics() { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let physical_plan = ctx.create_physical_plan(&plan).await.unwrap(); - let results = collect(physical_plan.clone()).await.unwrap(); + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let results = collect(physical_plan.clone(), runtime).await.unwrap(); let formatted = arrow::util::pretty::pretty_format_batches(&results) .unwrap() .to_string(); @@ -105,8 +106,9 @@ async fn explain_analyze_baseline_metrics() { fn expected_to_have_metrics(plan: &dyn ExecutionPlan) -> bool { use datafusion::physical_plan; + use datafusion::physical_plan::sorts; - plan.as_any().downcast_ref::().is_some() + plan.as_any().downcast_ref::().is_some() || plan.as_any().downcast_ref::().is_some() // CoalescePartitionsExec doesn't do any work so is not included || plan.as_any().downcast_ref::().is_some() @@ -327,7 +329,8 @@ async fn csv_explain_plans() { // // Execute plan let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); - let results = collect(plan).await.expect(&msg); + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let results = collect(plan, runtime).await.expect(&msg); let actual = result_vec(&results); // flatten to a single string let actual = actual.into_iter().map(|r| r.join("\t")).collect::(); @@ -524,7 +527,8 @@ async fn csv_explain_verbose_plans() { // // Execute plan let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); - let results = collect(plan).await.expect(&msg); + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let results = collect(plan, runtime).await.expect(&msg); let actual = result_vec(&results); // flatten to a single string let actual = actual.into_iter().map(|r| r.join("\t")).collect::(); diff --git a/datafusion/tests/sql/mod.rs b/datafusion/tests/sql/mod.rs index 3cc129e73115..cd854c2ba41e 100644 --- a/datafusion/tests/sql/mod.rs +++ b/datafusion/tests/sql/mod.rs @@ -484,7 +484,8 @@ async fn execute_to_batches(ctx: &mut ExecutionContext, sql: &str) -> Vec Result<()> { let plan = ctx.create_physical_plan(&plan).await.expect(&msg); let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); - let res = collect(plan).await.expect(&msg); + let runtime = ctx.state.lock().unwrap().runtime_env.clone(); + let res = collect(plan, runtime).await.expect(&msg); let actual = result_vec(&res); let res1 = actual[0][0].as_str(); diff --git a/datafusion/tests/statistics.rs b/datafusion/tests/statistics.rs index 2934d7889215..0e9771789bf8 100644 --- a/datafusion/tests/statistics.rs +++ b/datafusion/tests/statistics.rs @@ -33,6 +33,7 @@ use datafusion::{ }; use async_trait::async_trait; +use datafusion::execution::runtime_env::RuntimeEnv; /// This is a testing structure for statistics /// It will act both as a table provider and execution plan @@ -144,7 +145,11 @@ impl ExecutionPlan for StatisticsValidation { } } - async fn execute(&self, _partition: usize) -> Result { + async fn execute( + &self, + _partition: usize, + _runtime: Arc, + ) -> Result { unimplemented!("This plan only serves for testing statistics") } diff --git a/datafusion/tests/user_defined_plan.rs b/datafusion/tests/user_defined_plan.rs index b603f6a87701..a8d19d6aeb6a 100644 --- a/datafusion/tests/user_defined_plan.rs +++ b/datafusion/tests/user_defined_plan.rs @@ -86,6 +86,7 @@ use std::{any::Any, collections::BTreeMap, fmt, sync::Arc}; use async_trait::async_trait; use datafusion::execution::context::ExecutionProps; +use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_plan::plan::{Extension, Sort}; use datafusion::logical_plan::{DFSchemaRef, Limit}; @@ -455,7 +456,11 @@ impl ExecutionPlan for TopKExec { } /// Execute one partition and return an iterator over RecordBatch - async fn execute(&self, partition: usize) -> Result { + async fn execute( + &self, + partition: usize, + runtime: Arc, + ) -> Result { if 0 != partition { return Err(DataFusionError::Internal(format!( "TopKExec invalid partition {}", @@ -464,7 +469,7 @@ impl ExecutionPlan for TopKExec { } Ok(Box::pin(TopKReader { - input: self.input.execute(partition).await?, + input: self.input.execute(partition, runtime).await?, k: self.k, done: false, state: BTreeMap::new(),