From 6ec18bb4a53f684efd8d97443c55035eb37bda14 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 23 Jan 2022 14:14:04 -0500 Subject: [PATCH] Consolidate Schema and RecordBatch projection (#1638) --- ballista/rust/core/src/memory_stream.rs | 13 ++++--- datafusion/src/datasource/empty.rs | 14 ++------ datafusion/src/datasource/listing/table.rs | 9 ++--- datafusion/src/datasource/memory.rs | 10 ++++-- datafusion/src/physical_plan/memory.rs | 41 +++++++--------------- datafusion/src/physical_plan/mod.rs | 40 +++++++++++++++++++++ datafusion/tests/custom_sources.rs | 9 ++--- datafusion/tests/statistics.rs | 18 +++------- 8 files changed, 77 insertions(+), 77 deletions(-) diff --git a/ballista/rust/core/src/memory_stream.rs b/ballista/rust/core/src/memory_stream.rs index ab72bdc82aee..0c0ba4b4a88b 100644 --- a/ballista/rust/core/src/memory_stream.rs +++ b/ballista/rust/core/src/memory_stream.rs @@ -67,13 +67,12 @@ impl Stream for MemoryStream { let batch = &self.data[self.index - 1]; // apply projection - match &self.projection { - Some(columns) => Some(RecordBatch::try_new( - self.schema.clone(), - columns.iter().map(|i| batch.column(*i).clone()).collect(), - )), - None => Some(Ok(batch.clone())), - } + let next_batch = match &self.projection { + Some(projection) => batch.project(projection)?, + None => batch.clone(), + }; + + Some(Ok(next_batch)) } else { None }) diff --git a/datafusion/src/datasource/empty.rs b/datafusion/src/datasource/empty.rs index 966560518cb2..5622d15a0d67 100644 --- a/datafusion/src/datasource/empty.rs +++ b/datafusion/src/datasource/empty.rs @@ -26,6 +26,7 @@ use async_trait::async_trait; use crate::datasource::TableProvider; use crate::error::Result; use crate::logical_plan::Expr; +use crate::physical_plan::project_schema; use crate::physical_plan::{empty::EmptyExec, ExecutionPlan}; /// A table with a schema but no data. @@ -57,16 +58,7 @@ impl TableProvider for EmptyTable { _limit: Option, ) -> Result> { // even though there is no data, projections apply - let projection = match projection.clone() { - Some(p) => p, - None => (0..self.schema.fields().len()).collect(), - }; - let projected_schema = Schema::new( - projection - .iter() - .map(|i| self.schema.field(*i).clone()) - .collect(), - ); - Ok(Arc::new(EmptyExec::new(false, Arc::new(projected_schema)))) + let projected_schema = project_schema(&self.schema, projection.as_ref())?; + Ok(Arc::new(EmptyExec::new(false, projected_schema))) } } diff --git a/datafusion/src/datasource/listing/table.rs b/datafusion/src/datasource/listing/table.rs index ff6d32210661..2f8f70f5ede5 100644 --- a/datafusion/src/datasource/listing/table.rs +++ b/datafusion/src/datasource/listing/table.rs @@ -29,7 +29,7 @@ use crate::{ physical_plan::{ empty::EmptyExec, file_format::{FileScanConfig, DEFAULT_PARTITION_COLUMN_DATATYPE}, - ExecutionPlan, Statistics, + project_schema, ExecutionPlan, Statistics, }, }; @@ -179,12 +179,7 @@ impl TableProvider for ListingTable { // if no files need to be read, return an `EmptyExec` if partitioned_file_lists.is_empty() { let schema = self.schema(); - let projected_schema = match &projection { - None => schema, - Some(p) => Arc::new(Schema::new( - p.iter().map(|i| schema.field(*i).clone()).collect(), - )), - }; + let projected_schema = project_schema(&schema, projection.as_ref())?; return Ok(Arc::new(EmptyExec::new(false, projected_schema))); } diff --git a/datafusion/src/datasource/memory.rs b/datafusion/src/datasource/memory.rs index c732b17d0e33..5fad702672ef 100644 --- a/datafusion/src/datasource/memory.rs +++ b/datafusion/src/datasource/memory.rs @@ -147,6 +147,7 @@ mod tests { use crate::from_slice::FromSlice; use arrow::array::Int32Array; use arrow::datatypes::{DataType, Field, Schema}; + use arrow::error::ArrowError; use futures::StreamExt; use std::collections::HashMap; @@ -235,10 +236,13 @@ mod tests { let projection: Vec = vec![0, 4]; match provider.scan(&Some(projection), &[], None).await { - Err(DataFusionError::Internal(e)) => { - assert_eq!("\"Projection index out of range\"", format!("{:?}", e)) + Err(DataFusionError::ArrowError(ArrowError::SchemaError(e))) => { + assert_eq!( + "\"project index 4 out of bounds, max field 3\"", + format!("{:?}", e) + ) } - _ => panic!("Scan should failed on invalid projection"), + res => panic!("Scan should failed on invalid projection, got {:?}", res), }; Ok(()) diff --git a/datafusion/src/physical_plan/memory.rs b/datafusion/src/physical_plan/memory.rs index 61be207720ee..8e32b097630f 100644 --- a/datafusion/src/physical_plan/memory.rs +++ b/datafusion/src/physical_plan/memory.rs @@ -23,11 +23,11 @@ use std::sync::Arc; use std::task::{Context, Poll}; use super::{ - common, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, - SendableRecordBatchStream, Statistics, + common, project_schema, DisplayFormatType, ExecutionPlan, Partitioning, + RecordBatchStream, SendableRecordBatchStream, Statistics, }; use crate::error::{DataFusionError, Result}; -use arrow::datatypes::{Field, Schema, SchemaRef}; +use arrow::datatypes::SchemaRef; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; @@ -136,24 +136,7 @@ impl MemoryExec { schema: SchemaRef, projection: Option>, ) -> Result { - let projected_schema = match &projection { - Some(columns) => { - let fields: Result> = columns - .iter() - .map(|i| { - if *i < schema.fields().len() { - Ok(schema.field(*i).clone()) - } else { - Err(DataFusionError::Internal( - "Projection index out of range".to_string(), - )) - } - }) - .collect(); - Arc::new(Schema::new(fields?)) - } - None => Arc::clone(&schema), - }; + let projected_schema = project_schema(&schema, projection.as_ref())?; Ok(Self { partitions: partitions.to_vec(), schema, @@ -201,14 +184,14 @@ impl Stream for MemoryStream { Poll::Ready(if self.index < self.data.len() { self.index += 1; let batch = &self.data[self.index - 1]; - // apply projection - match &self.projection { - Some(columns) => Some(RecordBatch::try_new( - self.schema.clone(), - columns.iter().map(|i| batch.column(*i).clone()).collect(), - )), - None => Some(Ok(batch.clone())), - } + + // return just the columns requested + let batch = match self.projection.as_ref() { + Some(columns) => batch.project(columns)?, + None => batch.clone(), + }; + + Some(Ok(batch)) } else { None }) diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index ce127224c53e..216d4a65e639 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -586,6 +586,46 @@ pub trait Accumulator: Send + Sync + Debug { fn evaluate(&self) -> Result; } +/// Applies an optional projection to a [`SchemaRef`], returning the +/// projected schema +/// +/// Example: +/// ``` +/// use arrow::datatypes::{SchemaRef, Schema, Field, DataType}; +/// use datafusion::physical_plan::project_schema; +/// +/// // Schema with columns 'a', 'b', and 'c' +/// let schema = SchemaRef::new(Schema::new(vec![ +/// Field::new("a", DataType::Int32, true), +/// Field::new("b", DataType::Int64, true), +/// Field::new("c", DataType::Utf8, true), +/// ])); +/// +/// // Pick columns 'c' and 'b' +/// let projection = Some(vec![2,1]); +/// let projected_schema = project_schema( +/// &schema, +/// projection.as_ref() +/// ).unwrap(); +/// +/// let expected_schema = SchemaRef::new(Schema::new(vec![ +/// Field::new("c", DataType::Utf8, true), +/// Field::new("b", DataType::Int64, true), +/// ])); +/// +/// assert_eq!(projected_schema, expected_schema); +/// ``` +pub fn project_schema( + schema: &SchemaRef, + projection: Option<&Vec>, +) -> Result { + let schema = match projection { + Some(columns) => Arc::new(schema.project(columns)?), + None => Arc::clone(schema), + }; + Ok(schema) +} + pub mod aggregates; pub mod analyze; pub mod array_expressions; diff --git a/datafusion/tests/custom_sources.rs b/datafusion/tests/custom_sources.rs index c2511bac667c..e069dd750c18 100644 --- a/datafusion/tests/custom_sources.rs +++ b/datafusion/tests/custom_sources.rs @@ -34,7 +34,7 @@ use datafusion::logical_plan::{ col, Expr, LogicalPlan, LogicalPlanBuilder, TableScan, UNNAMED_TABLE, }; use datafusion::physical_plan::{ - ColumnStatistics, ExecutionPlan, Partitioning, RecordBatchStream, + project_schema, ColumnStatistics, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, }; @@ -108,12 +108,7 @@ impl ExecutionPlan for CustomExecutionPlan { } fn schema(&self) -> SchemaRef { let schema = TEST_CUSTOM_SCHEMA_REF!(); - match &self.projection { - None => schema, - Some(p) => Arc::new(Schema::new( - p.iter().map(|i| schema.field(*i).clone()).collect(), - )), - } + project_schema(&schema, self.projection.as_ref()).expect("projected schema") } fn output_partitioning(&self) -> Partitioning { Partitioning::UnknownPartitioning(1) diff --git a/datafusion/tests/statistics.rs b/datafusion/tests/statistics.rs index 4964baff8f28..3bc3720c670e 100644 --- a/datafusion/tests/statistics.rs +++ b/datafusion/tests/statistics.rs @@ -25,7 +25,7 @@ use datafusion::{ error::{DataFusionError, Result}, logical_plan::Expr, physical_plan::{ - ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning, + project_schema, ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }, prelude::ExecutionContext, @@ -44,7 +44,7 @@ struct StatisticsValidation { } impl StatisticsValidation { - fn new(stats: Statistics, schema: Schema) -> Self { + fn new(stats: Statistics, schema: SchemaRef) -> Self { assert!( stats .column_statistics @@ -53,10 +53,7 @@ impl StatisticsValidation { .unwrap_or(true), "if defined, the column statistics vector length should be the number of fields" ); - Self { - stats, - schema: Arc::new(schema), - } + Self { stats, schema } } } @@ -87,12 +84,7 @@ impl TableProvider for StatisticsValidation { Some(p) => p, None => (0..self.schema.fields().len()).collect(), }; - let projected_schema = Schema::new( - projection - .iter() - .map(|i| self.schema.field(*i).clone()) - .collect(), - ); + let projected_schema = project_schema(&self.schema, Some(&projection))?; let current_stat = self.stats.clone(); @@ -177,7 +169,7 @@ impl ExecutionPlan for StatisticsValidation { fn init_ctx(stats: Statistics, schema: Schema) -> Result { let mut ctx = ExecutionContext::new(); let provider: Arc = - Arc::new(StatisticsValidation::new(stats, schema)); + Arc::new(StatisticsValidation::new(stats, Arc::new(schema))); ctx.register_table("stats_table", provider)?; Ok(ctx) }