Make ExecutionPlan sync (apache#2307)
tustvold committed May 4, 2022
1 parent b7bb2cf commit 2ae8cd8
Showing 45 changed files with 160 additions and 259 deletions.
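
The change in a nutshell: `ExecutionPlan::execute` is now a synchronous method. It still returns a `SendableRecordBatchStream`, so the actual work remains asynchronous and lazy; only obtaining the stream no longer requires an `.await`. Consequently every implementation drops `#[async_trait]` and the `use async_trait::async_trait;` import, and every call site changes from `execute(...).await?` to `execute(...)?`. A minimal caller-side sketch of the new API (the helper `collect_partition` is illustrative, not part of this commit):

use std::sync::Arc;

use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::Result;
use datafusion::execution::context::TaskContext;
use datafusion::physical_plan::{common, ExecutionPlan};

// Illustrative helper: collect one partition of a physical plan.
async fn collect_partition(
    plan: Arc<dyn ExecutionPlan>,
    partition: usize,
    context: Arc<TaskContext>,
) -> Result<Vec<RecordBatch>> {
    // Before this commit: let stream = plan.execute(partition, context).await?;
    let stream = plan.execute(partition, context)?;
    // Consuming the stream is still asynchronous.
    common::collect(stream).await
}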
4 changes: 1 addition & 3 deletions ballista/rust/core/src/execution_plans/distributed_query.rs
@@ -41,7 +41,6 @@ use datafusion::physical_plan::{
 
 use crate::serde::protobuf::execute_query_params::OptionalSessionId;
 use crate::serde::{AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec};
-use async_trait::async_trait;
 use datafusion::arrow::error::{ArrowError, Result as ArrowResult};
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::execution::context::TaskContext;
@@ -122,7 +121,6 @@ impl<T: 'static + AsLogicalPlan> DistributedQueryExec<T> {
     }
 }
 
-#[async_trait]
 impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
     fn as_any(&self) -> &dyn Any {
         self
@@ -162,7 +160,7 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
         }))
     }
 
-    async fn execute(
+    fn execute(
         &self,
         partition: usize,
         _context: Arc<TaskContext>,
4 changes: 1 addition & 3 deletions ballista/rust/core/src/execution_plans/shuffle_reader.rs
@@ -21,7 +21,6 @@ use std::sync::Arc;
 use crate::client::BallistaClient;
 use crate::serde::scheduler::{PartitionLocation, PartitionStats};
 
-use async_trait::async_trait;
 use datafusion::arrow::datatypes::SchemaRef;
 
 use datafusion::error::{DataFusionError, Result};
@@ -64,7 +63,6 @@ impl ShuffleReaderExec {
     }
 }
 
-#[async_trait]
 impl ExecutionPlan for ShuffleReaderExec {
     fn as_any(&self) -> &dyn Any {
         self
@@ -101,7 +99,7 @@ impl ExecutionPlan for ShuffleReaderExec {
         ))
     }
 
-    async fn execute(
+    fn execute(
         &self,
         partition: usize,
         _context: Arc<TaskContext>,
10 changes: 4 additions & 6 deletions ballista/rust/core/src/execution_plans/shuffle_writer.rs
@@ -33,7 +33,6 @@ use crate::utils;
 
 use crate::serde::protobuf::ShuffleWritePartition;
 use crate::serde::scheduler::PartitionStats;
-use async_trait::async_trait;
 use datafusion::arrow::array::{
     ArrayBuilder, ArrayRef, StringBuilder, StructBuilder, UInt32Builder, UInt64Builder,
 };
@@ -155,7 +154,7 @@ impl ShuffleWriterExec {
 
         async move {
             let now = Instant::now();
-            let mut stream = plan.execute(input_partition, context).await?;
+            let mut stream = plan.execute(input_partition, context)?;
 
             match output_partitioning {
                 None => {
@@ -293,7 +292,6 @@ impl ShuffleWriterExec {
     }
 }
 
-#[async_trait]
 impl ExecutionPlan for ShuffleWriterExec {
     fn as_any(&self) -> &dyn Any {
         self
@@ -336,7 +334,7 @@ impl ExecutionPlan for ShuffleWriterExec {
         )?))
     }
 
-    async fn execute(
+    fn execute(
         &self,
         partition: usize,
         context: Arc<TaskContext>,
@@ -459,7 +457,7 @@ mod tests {
             work_dir.into_path().to_str().unwrap().to_owned(),
            Some(Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 2)),
         )?;
-        let mut stream = query_stage.execute(0, task_ctx).await?;
+        let mut stream = query_stage.execute(0, task_ctx)?;
         let batches = utils::collect_stream(&mut stream)
             .await
             .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
@@ -516,7 +514,7 @@ mod tests {
             work_dir.into_path().to_str().unwrap().to_owned(),
             Some(Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 2)),
         )?;
-        let mut stream = query_stage.execute(0, task_ctx).await?;
+        let mut stream = query_stage.execute(0, task_ctx)?;
         let batches = utils::collect_stream(&mut stream)
             .await
             .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
4 changes: 1 addition & 3 deletions ballista/rust/core/src/execution_plans/unresolved_shuffle.rs
@@ -18,7 +18,6 @@
 use std::any::Any;
 use std::sync::Arc;
 
-use async_trait::async_trait;
 use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::error::{DataFusionError, Result};
 use datafusion::execution::context::TaskContext;
@@ -63,7 +62,6 @@ impl UnresolvedShuffleExec {
     }
 }
 
-#[async_trait]
 impl ExecutionPlan for UnresolvedShuffleExec {
     fn as_any(&self) -> &dyn Any {
         self
@@ -101,7 +99,7 @@ impl ExecutionPlan for UnresolvedShuffleExec {
         ))
     }
 
-    async fn execute(
+    fn execute(
         &self,
         _partition: usize,
         _context: Arc<TaskContext>,
3 changes: 1 addition & 2 deletions ballista/rust/core/src/serde/mod.rs
@@ -477,7 +477,6 @@ mod tests {
         }
     }
 
-    #[async_trait]
     impl ExecutionPlan for TopKExec {
         /// Return a reference to Any that can be used for downcasting
         fn as_any(&self) -> &dyn Any {
@@ -515,7 +514,7 @@ impl ExecutionPlan for TopKExec {
         }
 
         /// Execute one partition and return an iterator over RecordBatch
-        async fn execute(
+        fn execute(
             &self,
             _partition: usize,
             _context: Arc<TaskContext>,
10 changes: 3 additions & 7 deletions ballista/rust/executor/src/collect.rs
@@ -22,7 +22,6 @@ use std::sync::Arc;
 use std::task::{Context, Poll};
 use std::{any::Any, pin::Pin};
 
-use async_trait::async_trait;
 use datafusion::arrow::{
     datatypes::SchemaRef, error::Result as ArrowResult, record_batch::RecordBatch,
 };
@@ -49,7 +48,6 @@ impl CollectExec {
     }
 }
 
-#[async_trait]
 impl ExecutionPlan for CollectExec {
     fn as_any(&self) -> &dyn Any {
         self
@@ -78,18 +76,16 @@ impl ExecutionPlan for CollectExec {
         unimplemented!()
     }
 
-    async fn execute(
+    fn execute(
         &self,
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
         assert_eq!(0, partition);
         let num_partitions = self.plan.output_partitioning().partition_count();
 
-        let futures = (0..num_partitions).map(|i| self.plan.execute(i, context.clone()));
-        let streams = futures::future::join_all(futures)
-            .await
-            .into_iter()
+        let streams = (0..num_partitions)
+            .map(|i| self.plan.execute(i, context.clone()))
             .collect::<Result<Vec<_>>>()
             .map_err(|e| DataFusionError::Execution(format!("BallistaError: {:?}", e)))?;
 
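
A consequence visible in the CollectExec change above: fanning out over all partitions no longer needs `futures::future::join_all`; a plain iterator chain opens every stream and short-circuits on the first error. A standalone sketch of the same pattern (the function name is illustrative):

use std::sync::Arc;

use datafusion::error::Result;
use datafusion::execution::context::TaskContext;
use datafusion::physical_plan::{ExecutionPlan, SendableRecordBatchStream};

// Open one stream per output partition; Result's FromIterator impl
// stops at the first Err, so no join_all / await is needed anymore.
fn open_all_partitions(
    plan: &Arc<dyn ExecutionPlan>,
    context: &Arc<TaskContext>,
) -> Result<Vec<SendableRecordBatchStream>> {
    (0..plan.output_partitioning().partition_count())
        .map(|i| plan.execute(i, context.clone()))
        .collect()
}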
7 changes: 3 additions & 4 deletions datafusion-examples/examples/custom_datasource.rs
@@ -196,7 +196,6 @@ impl CustomExec {
     }
 }
 
-#[async_trait]
 impl ExecutionPlan for CustomExec {
     fn as_any(&self) -> &dyn Any {
         self
@@ -225,7 +224,7 @@ impl ExecutionPlan for CustomExec {
         Ok(self)
     }
 
-    async fn execute(
+    fn execute(
         &self,
         _partition: usize,
         _context: Arc<TaskContext>,
@@ -243,7 +242,7 @@ impl ExecutionPlan for CustomExec {
             account_array.append_value(user.bank_account)?;
         }
 
-        return Ok(Box::pin(MemoryStream::try_new(
+        Ok(Box::pin(MemoryStream::try_new(
             vec![RecordBatch::try_new(
                 self.projected_schema.clone(),
                 vec![
@@ -253,7 +252,7 @@ impl ExecutionPlan for CustomExec {
             )?],
             self.schema(),
             None,
-        )?));
+        )?))
     }
 
     fn statistics(&self) -> Statistics {
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/file_format/csv.rs
@@ -161,7 +161,7 @@ mod tests {
         let projection = Some(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12]);
         let exec = get_exec("aggregate_test_100.csv", &projection, None).await?;
         let task_ctx = ctx.task_ctx();
-        let stream = exec.execute(0, task_ctx).await?;
+        let stream = exec.execute(0, task_ctx)?;
 
         let tt_batches: i32 = stream
             .map(|batch| {
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/file_format/json.rs
@@ -131,7 +131,7 @@ mod tests {
         let projection = None;
         let exec = get_exec(&projection, None).await?;
         let task_ctx = ctx.task_ctx();
-        let stream = exec.execute(0, task_ctx).await?;
+        let stream = exec.execute(0, task_ctx)?;
 
         let tt_batches: i32 = stream
             .map(|batch| {
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/file_format/parquet.rs
@@ -512,7 +512,7 @@ mod tests {
         let projection = None;
         let exec = get_exec("alltypes_plain.parquet", &projection, None).await?;
         let task_ctx = ctx.task_ctx();
-        let stream = exec.execute(0, task_ctx).await?;
+        let stream = exec.execute(0, task_ctx)?;
 
         let tt_batches = stream
             .map(|batch| {
10 changes: 5 additions & 5 deletions datafusion/core/src/datasource/memory.rs
@@ -76,7 +76,7 @@ impl MemTable {
                 let context1 = context.clone();
                 let exec = exec.clone();
                 tokio::spawn(async move {
-                    let stream = exec.execute(part_i, context1.clone()).await?;
+                    let stream = exec.execute(part_i, context1.clone())?;
                     common::collect(stream).await
                 })
             })
@@ -103,7 +103,7 @@ impl MemTable {
         let mut output_partitions = vec![];
         for i in 0..exec.output_partitioning().partition_count() {
             // execute this *output* partition and collect all batches
-            let mut stream = exec.execute(i, context.clone()).await?;
+            let mut stream = exec.execute(i, context.clone())?;
             let mut batches = vec![];
             while let Some(result) = stream.next().await {
                 batches.push(result?);
@@ -177,7 +177,7 @@ mod tests {
 
         // scan with projection
         let exec = provider.scan(&Some(vec![2, 1]), &[], None).await?;
-        let mut it = exec.execute(0, task_ctx).await?;
+        let mut it = exec.execute(0, task_ctx)?;
         let batch2 = it.next().await.unwrap()?;
         assert_eq!(2, batch2.schema().fields().len());
         assert_eq!("c", batch2.schema().field(0).name());
@@ -209,7 +209,7 @@ mod tests {
         let provider = MemTable::try_new(schema, vec![vec![batch]])?;
 
         let exec = provider.scan(&None, &[], None).await?;
-        let mut it = exec.execute(0, task_ctx).await?;
+        let mut it = exec.execute(0, task_ctx)?;
         let batch1 = it.next().await.unwrap()?;
         assert_eq!(3, batch1.schema().fields().len());
         assert_eq!(3, batch1.num_columns());
@@ -365,7 +365,7 @@ mod tests {
             MemTable::try_new(Arc::new(merged_schema), vec![vec![batch1, batch2]])?;
 
         let exec = provider.scan(&None, &[], None).await?;
-        let mut it = exec.execute(0, task_ctx).await?;
+        let mut it = exec.execute(0, task_ctx)?;
         let batch1 = it.next().await.unwrap()?;
         assert_eq!(3, batch1.schema().fields().len());
         assert_eq!(3, batch1.num_columns());
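
For parallel materialization like `MemTable::load` above, the sync `execute` composes naturally with `tokio::spawn`: each task opens its partition stream synchronously inside the spawned future and then drains it. A self-contained sketch of that pattern, assuming a tokio runtime (`load_partitions` is an illustrative name and the error mapping is simplified):

use std::sync::Arc;

use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::{DataFusionError, Result};
use datafusion::execution::context::TaskContext;
use datafusion::physical_plan::{common, ExecutionPlan};

// Illustrative: drain every output partition in parallel.
async fn load_partitions(
    exec: Arc<dyn ExecutionPlan>,
    context: Arc<TaskContext>,
) -> Result<Vec<Vec<RecordBatch>>> {
    let handles: Vec<_> = (0..exec.output_partitioning().partition_count())
        .map(|i| {
            let exec = exec.clone();
            let context = context.clone();
            tokio::spawn(async move {
                // Opening the stream is synchronous; draining it still awaits.
                let stream = exec.execute(i, context)?;
                common::collect(stream).await
            })
        })
        .collect();

    let mut partitions = Vec::with_capacity(handles.len());
    for handle in handles {
        let batches = handle
            .await
            .map_err(|e| DataFusionError::Execution(format!("join error: {}", e)))??;
        partitions.push(batches);
    }
    Ok(partitions)
}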
@@ -304,7 +304,7 @@ mod tests {
 
         // A ProjectionExec is a sign that the count optimization was applied
         assert!(optimized.as_any().is::<ProjectionExec>());
-        let result = common::collect(optimized.execute(0, task_ctx).await?).await?;
+        let result = common::collect(optimized.execute(0, task_ctx)?).await?;
         assert_eq!(result[0].schema(), Arc::new(Schema::new(vec![col])));
         assert_eq!(
             result[0]
15 changes: 5 additions & 10 deletions datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -29,7 +29,6 @@ use crate::physical_plan::{
 };
 use arrow::array::ArrayRef;
 use arrow::datatypes::{Field, Schema, SchemaRef};
-use async_trait::async_trait;
 use datafusion_common::Result;
 use datafusion_expr::Accumulator;
 use datafusion_physical_expr::expressions::Column;
@@ -145,7 +144,6 @@ impl AggregateExec {
     }
 }
 
-#[async_trait]
 impl ExecutionPlan for AggregateExec {
     /// Return a reference to Any that can be used for down-casting
     fn as_any(&self) -> &dyn Any {
@@ -196,12 +194,12 @@ impl ExecutionPlan for AggregateExec {
         )?))
     }
 
-    async fn execute(
+    fn execute(
         &self,
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
-        let input = self.input.execute(partition, context).await?;
+        let input = self.input.execute(partition, context)?;
         let group_expr = self.group_expr.iter().map(|x| x.0.clone()).collect();
 
         let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
@@ -417,7 +415,6 @@ mod tests {
     use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
     use arrow::error::Result as ArrowResult;
     use arrow::record_batch::RecordBatch;
-    use async_trait::async_trait;
     use datafusion_common::{DataFusionError, Result};
     use datafusion_physical_expr::{AggregateExpr, PhysicalExpr, PhysicalSortExpr};
     use futures::{FutureExt, Stream};
@@ -489,8 +486,7 @@ mod tests {
         )?);
 
         let result =
-            common::collect(partial_aggregate.execute(0, task_ctx.clone()).await?)
-                .await?;
+            common::collect(partial_aggregate.execute(0, task_ctx.clone())?).await?;
 
         let expected = vec![
             "+---+---------------+-------------+",
@@ -522,7 +518,7 @@ mod tests {
         )?);
 
         let result =
-            common::collect(merged_aggregate.execute(0, task_ctx.clone()).await?).await?;
+            common::collect(merged_aggregate.execute(0, task_ctx.clone())?).await?;
         assert_eq!(result.len(), 1);
 
         let batch = &result[0];
@@ -556,7 +552,6 @@ mod tests {
         pub yield_first: bool,
     }
 
-    #[async_trait]
     impl ExecutionPlan for TestYieldingExec {
         fn as_any(&self) -> &dyn Any {
             self
@@ -587,7 +582,7 @@ mod tests {
             )))
         }
 
-        async fn execute(
+        fn execute(
             &self,
             _partition: usize,
            _context: Arc<TaskContext>,