Consolidate Schema and RecordBatch projection (#1638)
alamb authored Jan 23, 2022
1 parent 01b5244 commit 6ec18bb
Showing 8 changed files with 77 additions and 77 deletions.
13 changes: 6 additions & 7 deletions ballista/rust/core/src/memory_stream.rs
@@ -67,13 +67,12 @@ impl Stream for MemoryStream {
             let batch = &self.data[self.index - 1];
 
             // apply projection
-            match &self.projection {
-                Some(columns) => Some(RecordBatch::try_new(
-                    self.schema.clone(),
-                    columns.iter().map(|i| batch.column(*i).clone()).collect(),
-                )),
-                None => Some(Ok(batch.clone())),
-            }
+            let next_batch = match &self.projection {
+                Some(projection) => batch.project(projection)?,
+                None => batch.clone(),
+            };
+
+            Some(Ok(next_batch))
         } else {
             None
         })
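For context, `RecordBatch::project` from the arrow crate performs the same column selection that the removed `RecordBatch::try_new` call spelled out by hand, and it projects the schema along with the columns. A minimal sketch of its behavior (assuming an arrow version contemporary with this change):

    use std::sync::Arc;

    use arrow::array::{ArrayRef, Int32Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;

    fn main() -> arrow::error::Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
                Arc::new(StringArray::from(vec!["x", "y", "z"])) as ArrayRef,
            ],
        )?;

        // Keep only column "b" (index 1); the schema and the columns
        // are projected together, so they cannot drift out of sync
        let projected = batch.project(&[1])?;
        assert_eq!(projected.num_columns(), 1);
        assert_eq!(projected.schema().field(0).name(), "b");
        Ok(())
    }
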
14 changes: 3 additions & 11 deletions datafusion/src/datasource/empty.rs
@@ -26,6 +26,7 @@ use async_trait::async_trait;
 use crate::datasource::TableProvider;
 use crate::error::Result;
 use crate::logical_plan::Expr;
+use crate::physical_plan::project_schema;
 use crate::physical_plan::{empty::EmptyExec, ExecutionPlan};
 
 /// A table with a schema but no data.
Expand Down Expand Up @@ -57,16 +58,7 @@ impl TableProvider for EmptyTable {
_limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
// even though there is no data, projections apply
let projection = match projection.clone() {
Some(p) => p,
None => (0..self.schema.fields().len()).collect(),
};
let projected_schema = Schema::new(
projection
.iter()
.map(|i| self.schema.field(*i).clone())
.collect(),
);
Ok(Arc::new(EmptyExec::new(false, Arc::new(projected_schema))))
let projected_schema = project_schema(&self.schema, projection.as_ref())?;
Ok(Arc::new(EmptyExec::new(false, projected_schema)))
}
}
9 changes: 2 additions & 7 deletions datafusion/src/datasource/listing/table.rs
@@ -29,7 +29,7 @@ use crate::{
     physical_plan::{
         empty::EmptyExec,
         file_format::{FileScanConfig, DEFAULT_PARTITION_COLUMN_DATATYPE},
-        ExecutionPlan, Statistics,
+        project_schema, ExecutionPlan, Statistics,
     },
 };
 
@@ -179,12 +179,7 @@ impl TableProvider for ListingTable {
         // if no files need to be read, return an `EmptyExec`
         if partitioned_file_lists.is_empty() {
             let schema = self.schema();
-            let projected_schema = match &projection {
-                None => schema,
-                Some(p) => Arc::new(Schema::new(
-                    p.iter().map(|i| schema.field(*i).clone()).collect(),
-                )),
-            };
+            let projected_schema = project_schema(&schema, projection.as_ref())?;
             return Ok(Arc::new(EmptyExec::new(false, projected_schema)));
         }
 
10 changes: 7 additions & 3 deletions datafusion/src/datasource/memory.rs
@@ -147,6 +147,7 @@ mod tests {
     use crate::from_slice::FromSlice;
     use arrow::array::Int32Array;
     use arrow::datatypes::{DataType, Field, Schema};
+    use arrow::error::ArrowError;
     use futures::StreamExt;
     use std::collections::HashMap;
 
@@ -235,10 +236,13 @@
         let projection: Vec<usize> = vec![0, 4];
 
         match provider.scan(&Some(projection), &[], None).await {
-            Err(DataFusionError::Internal(e)) => {
-                assert_eq!("\"Projection index out of range\"", format!("{:?}", e))
+            Err(DataFusionError::ArrowError(ArrowError::SchemaError(e))) => {
+                assert_eq!(
+                    "\"project index 4 out of bounds, max field 3\"",
+                    format!("{:?}", e)
+                )
             }
-            _ => panic!("Scan should failed on invalid projection"),
+            res => panic!("Scan should fail on invalid projection, got {:?}", res),
         };
 
         Ok(())
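The new expected error in this test originates in arrow rather than in DataFusion: `Schema::project` (which the `project_schema` helper added below in datafusion/src/physical_plan/mod.rs delegates to) reports out-of-range indices as `ArrowError::SchemaError`. A hedged sketch of that error path on its own:

    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::error::ArrowError;

    fn main() {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);

        // Index 4 is out of range for a one-column schema, so arrow
        // returns its own SchemaError instead of a DataFusion error
        match schema.project(&[4]) {
            Err(ArrowError::SchemaError(msg)) => println!("schema error: {}", msg),
            other => panic!("expected SchemaError, got {:?}", other),
        }
    }

This is why the match arm above changed from `DataFusionError::Internal` to `DataFusionError::ArrowError(ArrowError::SchemaError(_))`.
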
41 changes: 12 additions & 29 deletions datafusion/src/physical_plan/memory.rs
@@ -23,11 +23,11 @@ use std::sync::Arc;
 use std::task::{Context, Poll};
 
 use super::{
-    common, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream,
-    SendableRecordBatchStream, Statistics,
+    common, project_schema, DisplayFormatType, ExecutionPlan, Partitioning,
+    RecordBatchStream, SendableRecordBatchStream, Statistics,
 };
 use crate::error::{DataFusionError, Result};
-use arrow::datatypes::{Field, Schema, SchemaRef};
+use arrow::datatypes::SchemaRef;
 use arrow::error::Result as ArrowResult;
 use arrow::record_batch::RecordBatch;
 
@@ -136,24 +136,7 @@ impl MemoryExec {
         schema: SchemaRef,
         projection: Option<Vec<usize>>,
     ) -> Result<Self> {
-        let projected_schema = match &projection {
-            Some(columns) => {
-                let fields: Result<Vec<Field>> = columns
-                    .iter()
-                    .map(|i| {
-                        if *i < schema.fields().len() {
-                            Ok(schema.field(*i).clone())
-                        } else {
-                            Err(DataFusionError::Internal(
-                                "Projection index out of range".to_string(),
-                            ))
-                        }
-                    })
-                    .collect();
-                Arc::new(Schema::new(fields?))
-            }
-            None => Arc::clone(&schema),
-        };
+        let projected_schema = project_schema(&schema, projection.as_ref())?;
         Ok(Self {
             partitions: partitions.to_vec(),
             schema,
@@ -201,14 +184,14 @@ impl Stream for MemoryStream {
         Poll::Ready(if self.index < self.data.len() {
             self.index += 1;
             let batch = &self.data[self.index - 1];
-            // apply projection
-            match &self.projection {
-                Some(columns) => Some(RecordBatch::try_new(
-                    self.schema.clone(),
-                    columns.iter().map(|i| batch.column(*i).clone()).collect(),
-                )),
-                None => Some(Ok(batch.clone())),
-            }
+
+            // return just the columns requested
+            let batch = match self.projection.as_ref() {
+                Some(columns) => batch.project(columns)?,
+                None => batch.clone(),
+            };
+
+            Some(Ok(batch))
         } else {
             None
         })
40 changes: 40 additions & 0 deletions datafusion/src/physical_plan/mod.rs
@@ -586,6 +586,46 @@ pub trait Accumulator: Send + Sync + Debug {
     fn evaluate(&self) -> Result<ScalarValue>;
 }
 
+/// Applies an optional projection to a [`SchemaRef`], returning the
+/// projected schema
+///
+/// Example:
+/// ```
+/// use arrow::datatypes::{SchemaRef, Schema, Field, DataType};
+/// use datafusion::physical_plan::project_schema;
+///
+/// // Schema with columns 'a', 'b', and 'c'
+/// let schema = SchemaRef::new(Schema::new(vec![
+///   Field::new("a", DataType::Int32, true),
+///   Field::new("b", DataType::Int64, true),
+///   Field::new("c", DataType::Utf8, true),
+/// ]));
+///
+/// // Pick columns 'c' and 'b'
+/// let projection = Some(vec![2, 1]);
+/// let projected_schema = project_schema(
+///   &schema,
+///   projection.as_ref()
+/// ).unwrap();
+///
+/// let expected_schema = SchemaRef::new(Schema::new(vec![
+///   Field::new("c", DataType::Utf8, true),
+///   Field::new("b", DataType::Int64, true),
+/// ]));
+///
+/// assert_eq!(projected_schema, expected_schema);
+/// ```
+pub fn project_schema(
+    schema: &SchemaRef,
+    projection: Option<&Vec<usize>>,
+) -> Result<SchemaRef> {
+    let schema = match projection {
+        Some(columns) => Arc::new(schema.project(columns)?),
+        None => Arc::clone(schema),
+    };
+    Ok(schema)
+}
+
 pub mod aggregates;
 pub mod analyze;
 pub mod array_expressions;
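One detail of the new helper worth calling out: when no projection is supplied it does not copy the schema, it simply clones the `Arc`. A small sketch of that fast path (the `Arc::ptr_eq` assertion checks pointer identity, not just structural equality):

    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use datafusion::physical_plan::project_schema;

    fn main() -> datafusion::error::Result<()> {
        let schema: SchemaRef =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));

        // None means "no projection": the input schema is returned via
        // Arc::clone, so the result points at the same allocation
        let unprojected = project_schema(&schema, None)?;
        assert!(Arc::ptr_eq(&schema, &unprojected));
        Ok(())
    }
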
9 changes: 2 additions & 7 deletions datafusion/tests/custom_sources.rs
@@ -34,7 +34,7 @@ use datafusion::logical_plan::{
     col, Expr, LogicalPlan, LogicalPlanBuilder, TableScan, UNNAMED_TABLE,
 };
 use datafusion::physical_plan::{
-    ColumnStatistics, ExecutionPlan, Partitioning, RecordBatchStream,
+    project_schema, ColumnStatistics, ExecutionPlan, Partitioning, RecordBatchStream,
     SendableRecordBatchStream, Statistics,
 };
 
@@ -108,12 +108,7 @@
     }
     fn schema(&self) -> SchemaRef {
         let schema = TEST_CUSTOM_SCHEMA_REF!();
-        match &self.projection {
-            None => schema,
-            Some(p) => Arc::new(Schema::new(
-                p.iter().map(|i| schema.field(*i).clone()).collect(),
-            )),
-        }
+        project_schema(&schema, self.projection.as_ref()).expect("projected schema")
     }
     fn output_partitioning(&self) -> Partitioning {
         Partitioning::UnknownPartitioning(1)
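Note the `.expect("projected schema")` here: `ExecutionPlan::schema` returns a bare `SchemaRef` rather than a `Result`, so the now-fallible projection has to be unwrapped at this call site. That is reasonable in this test fixture, where the projection indices are known to be valid.
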
18 changes: 5 additions & 13 deletions datafusion/tests/statistics.rs
@@ -25,7 +25,7 @@ use datafusion::{
     error::{DataFusionError, Result},
     logical_plan::Expr,
     physical_plan::{
-        ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning,
+        project_schema, ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning,
         SendableRecordBatchStream, Statistics,
     },
     prelude::ExecutionContext,
@@ -44,7 +44,7 @@ struct StatisticsValidation {
 }
 
 impl StatisticsValidation {
-    fn new(stats: Statistics, schema: Schema) -> Self {
+    fn new(stats: Statistics, schema: SchemaRef) -> Self {
         assert!(
             stats
                 .column_statistics
@@ -53,10 +53,7 @@ impl StatisticsValidation {
                 .unwrap_or(true),
             "if defined, the column statistics vector length should be the number of fields"
         );
-        Self {
-            stats,
-            schema: Arc::new(schema),
-        }
+        Self { stats, schema }
     }
 }
 
@@ -87,12 +84,7 @@
             Some(p) => p,
             None => (0..self.schema.fields().len()).collect(),
         };
-        let projected_schema = Schema::new(
-            projection
-                .iter()
-                .map(|i| self.schema.field(*i).clone())
-                .collect(),
-        );
+        let projected_schema = project_schema(&self.schema, Some(&projection))?;
 
         let current_stat = self.stats.clone();
 
@@ -177,7 +169,7 @@
 fn init_ctx(stats: Statistics, schema: Schema) -> Result<ExecutionContext> {
     let mut ctx = ExecutionContext::new();
     let provider: Arc<dyn TableProvider> =
-        Arc::new(StatisticsValidation::new(stats, schema));
+        Arc::new(StatisticsValidation::new(stats, Arc::new(schema)));
     ctx.register_table("stats_table", provider)?;
     Ok(ctx)
 }
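Unlike the other call sites, this one still materializes explicit projection indices (`Some(p) => p`, `None => (0..len).collect()`) and passes `Some(&projection)` rather than `projection.as_ref()`, presumably because this test provider also needs the concrete indices further down, e.g. to build the projected column statistics.
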
