Skip to content

Commit

Permalink
Add support for UNION sql
Browse files Browse the repository at this point in the history
  • Loading branch information
xudong963 committed Oct 4, 2021
1 parent 2f04d67 commit 332e472
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 2 deletions.
15 changes: 15 additions & 0 deletions datafusion/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,21 @@ pub trait DataFrame: Send + Sync {
/// ```
fn union(&self, dataframe: Arc<dyn DataFrame>) -> Result<Arc<dyn DataFrame>>;

/// Return a new [`DataFrame`] with duplicate rows removed.
///
/// This method takes no second [`DataFrame`]; it only de-duplicates the rows of `self`.
/// Applying it to the result of `union` gives the SQL `UNION [DISTINCT]` behaviour,
/// as shown in the example below.
///
/// ```
/// # use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # fn main() -> Result<()> {
/// let mut ctx = ExecutionContext::new();
/// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new())?;
/// let df = df.union(df.clone())?;
/// let df = df.distinct()?;
/// # Ok(())
/// # }
/// ```
fn distinct(&self) -> Result<Arc<dyn DataFrame>>;

/// Sort the DataFrame by the specified sorting expressions. Any expression can be turned into
/// a sort expression by calling its [sort](../logical_plan/enum.Expr.html#method.sort) method.
///
Expand Down
9 changes: 9 additions & 0 deletions datafusion/src/execution/dataframe_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,15 @@ impl DataFrame for DataFrameImpl {
.build()?;
Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
}

/// Build a new `DataFrameImpl` whose logical plan is this frame's plan
/// passed through `LogicalPlanBuilder::union_distinct`, sharing the same
/// context state as `self`.
fn distinct(&self) -> Result<Arc<dyn DataFrame>> {
    let plan = LogicalPlanBuilder::from(self.to_logical_plan())
        .union_distinct()?
        .build()?;
    Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
}
}

#[cfg(test)]
Expand Down
9 changes: 9 additions & 0 deletions datafusion/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,15 @@ impl LogicalPlanBuilder {
Ok(Self::from(union_with_alias(self.plan.clone(), plan, None)?))
}

/// De-duplicate the rows of the current plan.
///
/// Despite the name, no second input is involved: the method only wraps the
/// plan already held by this builder (typically a union) so that repeated
/// rows collapse into one. It does so by expanding the plan's schema into one
/// expression per column, feeding those expressions to `aggregate` together
/// with an empty second expression list, and finally projecting `*` over the
/// aggregated plan.
///
/// Returns an error if wildcard expansion, the aggregate, or the projection
/// cannot be built.
pub fn union_distinct(&self) -> Result<Self> {
    // One expression per column of the input schema.
    let all_columns = expand_wildcard(self.plan.schema(), &self.plan)?;

    // NOTE(review): presumably `aggregate` takes the grouping expressions
    // first and the aggregate expressions second, so this groups on every
    // column with no aggregate functions - confirm against the builder API.
    let aggregated = Self::from(self.plan.clone())
        .aggregate(all_columns, vec![])?
        .build()?;

    let builder = Self::from(aggregated);
    builder.project(vec![Expr::Wildcard])
}

/// Apply a join with on constraint
pub fn join(
&self,
Expand Down
12 changes: 10 additions & 2 deletions datafusion/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let right_plan = self.set_expr_to_plan(right.as_ref(), None, ctes)?;
union_with_alias(left_plan, right_plan, alias)
}
(SetOperator::Union, false) => {
let left_plan = self.set_expr_to_plan(left.as_ref(), None, ctes)?;
let right_plan = self.set_expr_to_plan(right.as_ref(), None, ctes)?;
let union_plan = union_with_alias(left_plan, right_plan, alias)?;
LogicalPlanBuilder::from(union_plan)
.union_distinct()?
.build()
}
_ => Err(DataFusionError::NotImplemented(format!(
"Only UNION ALL is supported, found {}",
"Only UNION ALL and UNION [DISTINCT] are supported, found {}",
op
))),
},
Expand Down Expand Up @@ -3440,7 +3448,7 @@ mod tests {
let sql = "SELECT order_id from orders EXCEPT SELECT order_id FROM orders";
let err = logical_plan(sql).expect_err("query should have failed");
assert_eq!(
"NotImplemented(\"Only UNION ALL is supported, found EXCEPT\")",
"NotImplemented(\"Only UNION ALL and UNION [DISTINCT] are supported, found EXCEPT\")",
format!("{:?}", err)
);
}
Expand Down
39 changes: 39 additions & 0 deletions datafusion/tests/user_defined_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ const QUERY1: &str = "SELECT * FROM sales limit 3";
// Selects the three `sales` rows with the highest revenue, highest first.
const QUERY: &str =
"SELECT customer_id, revenue FROM sales ORDER BY revenue DESC limit 3";

// Same selection as `QUERY`, but unioned with itself using `UNION` (without
// `ALL`) before ordering and limiting. Used by the union test below.
const QUERY2: &str = "SELECT customer_id, revenue FROM sales UNION SELECT customer_id, revenue FROM sales ORDER BY revenue DESC limit 3";

// Run the query using the specified execution context and compare it
// to the known result
async fn run_and_compare_query(
Expand Down Expand Up @@ -188,6 +190,43 @@ async fn run_and_compare_query_with_auto_schemas(
Ok(())
}

// Run `QUERY2` (a `UNION` of the sales query with itself) using the
// specified execution context and compare the formatted output, line by
// line, to the known result.
//
// `description` names the context under test and is only used in the
// failure message. Any error raised while executing the query is returned;
// a mismatch between expected and actual output panics via `assert_eq!`.
async fn run_and_compare_union_query(
    mut ctx: ExecutionContext,
    description: &str,
) -> Result<()> {
    // NOTE(review): the data rows are narrower than the border and header
    // lines (13 and 9 dashes) - confirm the column padding against the
    // actual pretty-printed output.
    let expected = vec![
        "+-------------+---------+",
        "| customer_id | revenue |",
        "+-------------+---------+",
        "| paul | 300 |",
        "| jorge | 200 |",
        "| andy | 150 |",
        "+-------------+---------+",
    ];

    let s = exec_sql(&mut ctx, QUERY2).await?;
    let actual = s.lines().collect::<Vec<_>>();

    // `expected.join("\n")` has no trailing newline, so one is added before
    // "Actual:" to keep the two tables on separate lines in the message.
    assert_eq!(
        expected,
        actual,
        "output mismatch for {}. Expected:\n{}\nActual:\n{}",
        description,
        expected.join("\n"),
        s
    );
    Ok(())
}

// Run the `UNION` query against a freshly set-up default context and check
// the result against the known output.
#[tokio::test]
async fn query_to_test_union() -> Result<()> {
    let fresh_ctx = ExecutionContext::new();
    let prepared_ctx = setup_table(fresh_ctx).await?;
    run_and_compare_union_query(prepared_ctx, "Default context").await
}

#[tokio::test]
// Run the query using default planners and optimizer
async fn normal_query_without_schemas() -> Result<()> {
Expand Down

0 comments on commit 332e472

Please sign in to comment.