Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Ballista] Support Union in ballista. #2098

Merged
merged 7 commits into from
Mar 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions ballista/rust/client/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -561,4 +561,64 @@ mod tests {
let df = context.sql(sql).await.unwrap();
assert!(!df.collect().await.unwrap().is_empty());
}

/// Runs `UNION` and `UNION ALL` through a standalone Ballista context and
/// checks the rendered result tables.
///
/// `UNION` must de-duplicate the two identical rows into one, while
/// `UNION ALL` must keep both.
#[tokio::test]
#[cfg(feature = "standalone")]
async fn test_union_and_union_all() {
    use super::*;
    use ballista_core::config::{
        BallistaConfigBuilder, BALLISTA_WITH_INFORMATION_SCHEMA,
    };
    use datafusion::assert_batches_eq;

    let config = BallistaConfigBuilder::default()
        .set(BALLISTA_WITH_INFORMATION_SCHEMA, "true")
        .build()
        .unwrap();
    let context = BallistaContext::standalone(&config, 1).await.unwrap();

    // UNION: duplicates are removed, so a single row is expected.
    let df = context
        .sql("SELECT 1 as NUMBER union SELECT 1 as NUMBER;")
        .await
        .unwrap();
    let res1 = df.collect().await.unwrap();
    // Data cells are padded to the column width, so every row is exactly as
    // wide as the border line.
    let expected1 = vec![
        "+--------+",
        "| number |",
        "+--------+",
        "| 1      |",
        "+--------+",
    ];
    // The macro formats the batches and compares them line by line, which
    // replaces the hand-written pretty_format_batches/trim/lines pipeline.
    assert_batches_eq!(expected1, &res1);

    // UNION ALL: duplicates are kept, so both rows are expected.
    let df = context
        .sql("SELECT 1 as NUMBER union all SELECT 1 as NUMBER;")
        .await
        .unwrap();
    let res2 = df.collect().await.unwrap();
    let expected2 = vec![
        "+--------+",
        "| number |",
        "+--------+",
        "| 1      |",
        "| 1      |",
        "+--------+",
    ];
    assert_batches_eq!(expected2, &res2);
}
}
10 changes: 10 additions & 0 deletions ballista/rust/core/proto/ballista.proto
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ message LogicalPlanNode {
ValuesNode values = 16;
LogicalExtensionNode extension = 17;
CreateCatalogSchemaNode create_catalog_schema = 18;
UnionNode union = 19;
}
}

Expand Down Expand Up @@ -212,6 +213,10 @@ message JoinNode {
bool null_equals_null = 7;
}

// Logical-plan node for a union.
message UnionNode {
// Child logical plans that are combined by the union, in order.
repeated LogicalPlanNode inputs = 1;
}

message CrossJoinNode {
LogicalPlanNode left = 1;
LogicalPlanNode right = 2;
Expand Down Expand Up @@ -253,6 +258,7 @@ message PhysicalPlanNode {
CrossJoinExecNode cross_join = 19;
AvroScanExecNode avro_scan = 20;
PhysicalExtensionNode extension = 21;
UnionExecNode union = 22;
}
}

Expand Down Expand Up @@ -433,6 +439,10 @@ message HashJoinExecNode {
bool null_equals_null = 7;
}

// Physical-plan node for a union (serialized form of UnionExec).
message UnionExecNode {
// Child physical plans that are combined by the union, in order.
repeated PhysicalPlanNode inputs = 1;
}

message CrossJoinExecNode {
PhysicalPlanNode left = 1;
PhysicalPlanNode right = 2;
Expand Down
37 changes: 36 additions & 1 deletion ballista/rust/core/src/serde/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,25 @@ impl AsLogicalPlan for LogicalPlanNode {

builder.build().map_err(|e| e.into())
}
LogicalPlanType::Union(union) => {
    // Decode every child first; the first child that fails to decode
    // aborts the whole conversion.
    let input_plans: Vec<LogicalPlan> = union
        .inputs
        .iter()
        .map(|i| i.try_into_logical_plan(ctx, extension_codec))
        .collect::<Result<_, BallistaError>>()?;

    // Consume the children front-to-back so the rebuilt union lists them
    // in the same order they were serialized in. Seeding the builder with
    // the last child instead would turn [a, b, c] into [c, a, b].
    let mut plans = input_plans.into_iter();
    let (first, second) = match (plans.next(), plans.next()) {
        (Some(first), Some(second)) => (first, second),
        // Fewer than two children cannot form a union.
        _ => {
            return Err(BallistaError::General(String::from(
                "Protobuf deserialization error, Union was require at least two input.",
            )));
        }
    };

    let mut builder = LogicalPlanBuilder::from(first).union(second)?;
    for plan in plans {
        builder = builder.union(plan)?;
    }
    builder.build().map_err(|e| e.into())
}
LogicalPlanType::CrossJoin(crossjoin) => {
let left = into_logical_plan!(crossjoin.left, &ctx, extension_codec)?;
let right = into_logical_plan!(crossjoin.right, &ctx, extension_codec)?;
Expand Down Expand Up @@ -815,7 +834,23 @@ impl AsLogicalPlan for LogicalPlanNode {
))),
})
}
LogicalPlan::Union(_) => unimplemented!(),
LogicalPlan::Union(union) => {
    // Serialize each child plan in order; the first failure is
    // propagated to the caller.
    let mut inputs: Vec<LogicalPlanNode> = Vec::new();
    for child in union.inputs.iter() {
        let encoded = protobuf::LogicalPlanNode::try_from_logical_plan(
            child,
            extension_codec,
        )?;
        inputs.push(encoded);
    }

    let union_node = protobuf::UnionNode { inputs };
    Ok(protobuf::LogicalPlanNode {
        logical_plan_type: Some(LogicalPlanType::Union(union_node)),
    })
}
LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => {
let left = protobuf::LogicalPlanNode::try_from_logical_plan(
left.as_ref(),
Expand Down
21 changes: 21 additions & 0 deletions ballista/rust/core/src/serde/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::union::UnionExec;
use datafusion::physical_plan::windows::{create_window_expr, WindowAggExec};
use datafusion::physical_plan::{
AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, WindowExpr,
Expand Down Expand Up @@ -382,6 +383,13 @@ impl AsExecutionPlan for PhysicalPlanNode {
&hashjoin.null_equals_null,
)?))
}
PhysicalPlanType::Union(union) => {
    // Decode the child plans in order, stopping at the first error.
    let children = union
        .inputs
        .iter()
        .map(|child| child.try_into_physical_plan(ctx, extension_codec))
        .collect::<Result<Vec<Arc<dyn ExecutionPlan>>, _>>()?;
    Ok(Arc::new(UnionExec::new(children)))
}
PhysicalPlanType::CrossJoin(crossjoin) => {
let left: Arc<dyn ExecutionPlan> =
into_physical_plan!(crossjoin.left, ctx, extension_codec)?;
Expand Down Expand Up @@ -866,6 +874,19 @@ impl AsExecutionPlan for PhysicalPlanNode {
},
)),
})
} else if let Some(union) = plan.downcast_ref::<UnionExec>() {
    // Serialize the child plans in order, stopping at the first error.
    let inputs = union
        .inputs()
        .iter()
        .map(|child| {
            protobuf::PhysicalPlanNode::try_from_physical_plan(
                Arc::clone(child),
                extension_codec,
            )
        })
        .collect::<Result<Vec<PhysicalPlanNode>, _>>()?;
    let union_node = protobuf::UnionExecNode { inputs };
    Ok(protobuf::PhysicalPlanNode {
        physical_plan_type: Some(PhysicalPlanType::Union(union_node)),
    })
} else {
let mut buf: Vec<u8> = vec![];
extension_codec.try_encode(plan_clone.clone(), &mut buf)?;
Expand Down
5 changes: 5 additions & 0 deletions datafusion/src/physical_plan/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ impl UnionExec {
metrics: ExecutionPlanMetricsSet::new(),
}
}

/// Returns the child execution plans of this union.
///
/// The plans are borrowed from `self`; callers that need ownership can
/// clone the individual `Arc`s.
pub fn inputs(&self) -> &Vec<Arc<dyn ExecutionPlan>> {
&self.inputs
}
}

#[async_trait]
Expand Down