Remove duplicated logic between DataFrame API and SQL planning
jiangzhx committed Mar 29, 2023
1 parent 7b67d28 commit 423e604
Showing 8 changed files with 304 additions and 59 deletions.
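
Context: before this change, the logic that expands a wildcard count was duplicated between the SQL planner and the DataFrame API, and the two front ends could produce different plans for the same query. The new tests in datafusion/core/tests/dataframe.rs assert that both paths now yield identical plans. Below is a minimal, self-contained sketch of that equivalence, built only from calls that already appear in this diff; the `tokio::main` wrapper and the standalone-program framing are illustrative assumptions, not part of the commit.

use datafusion::arrow::util::pretty::pretty_format_batches;
use datafusion::error::Result;
use datafusion::execution::context::SessionContext;
use datafusion::prelude::ParquetReadOptions;
use datafusion_expr::{count, Expr};

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Requires the parquet-testing data bundled with the DataFusion repository.
    let testdata = datafusion::test_util::parquet_test_data();
    ctx.register_parquet(
        "alltypes_tiny_pages",
        &format!("{testdata}/alltypes_tiny_pages.parquet"),
        ParquetReadOptions::default(),
    )
    .await?;

    // SQL front end: COUNT(*) is analyzed into COUNT(UInt8(1)).
    let sql_results = ctx
        .sql("select count(*) from alltypes_tiny_pages")
        .await?
        .explain(false, false)?
        .collect()
        .await?;

    // DataFrame front end: the same wildcard count, built programmatically.
    let df_results = ctx
        .table("alltypes_tiny_pages")
        .await?
        .aggregate(vec![], vec![count(Expr::Wildcard)])?
        .explain(false, false)?
        .collect()
        .await?;

    // After this commit both front ends should print the same logical plan.
    assert_eq!(
        pretty_format_batches(&sql_results)?.to_string(),
        pretty_format_batches(&df_results)?.to_string()
    );
    Ok(())
}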
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions datafusion/common/src/dfschema.rs
@@ -630,9 +630,9 @@ impl ExprSchema for DFSchema {
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DFField {
/// Optional qualifier (usually a table or relation name)
qualifier: Option<OwnedTableReference>,
pub qualifier: Option<OwnedTableReference>,
/// Arrow field definition
field: Field,
pub field: Field,
}

impl DFField {
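
The dfschema.rs change makes the `qualifier` and `field` members of `DFField` public, so planning code shared by both front ends can read a field's parts directly rather than only through the `qualifier()`/`field()` accessors. A hypothetical sketch of what that direct access allows; the `describe` helper and the use of `Display` on the qualifier are illustrative assumptions, not code from this commit.

use datafusion_common::DFSchema;

// List every column of a DFSchema as "qualifier.name" (or just "name"),
// reading the now-public members of DFField directly.
fn describe(schema: &DFSchema) -> Vec<String> {
    schema
        .fields()
        .iter()
        .map(|f| match &f.qualifier {
            Some(q) => format!("{q}.{}", f.field.name()),
            None => f.field.name().to_string(),
        })
        .collect()
}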
130 changes: 107 additions & 23 deletions datafusion/core/tests/dataframe.rs
@@ -32,36 +32,93 @@ use datafusion::error::Result;
use datafusion::execution::context::SessionContext;
use datafusion::prelude::JoinType;
use datafusion::prelude::{CsvReadOptions, ParquetReadOptions};
use datafusion::test_util::parquet_test_data;
use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
use datafusion_common::ScalarValue;
use datafusion_common::ScalarValue::UInt64;
use datafusion_expr::expr::{GroupingSet, Sort};
use datafusion_expr::{avg, col, count, lit, max, sum, Expr, ExprSchemable};
use datafusion_expr::Expr::Wildcard;
use datafusion_expr::{
avg, col, count, expr, lit, max, sum, AggregateFunction, Expr, ExprSchemable,
Subquery, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction,
};

#[tokio::test]
async fn test_count_wildcard_on_where_exist() -> Result<()> {
let ctx = create_join_context()?;

let df_results = ctx
.table("t1")
.await?
.select(vec![col("a"), col("b")])?
.filter(Expr::Exists {
subquery: Subquery {
subquery: Arc::new(
ctx.table("t2")
.await?
.aggregate(vec![], vec![count(Expr::Wildcard)])?
.select(vec![count(Expr::Wildcard)])?
.into_optimized_plan()?,
),
outer_ref_columns: vec![],
},
negated: false,
})?
.explain(false, false)?
.collect()
.await?;

#[rustfmt::skip]
let expected = vec![
"+--------------+-------------------------------------------------------+",
"| plan_type | plan |",
"+--------------+-------------------------------------------------------+",
"| logical_plan | Filter: EXISTS (<subquery>) |",
"| | Subquery: |",
"| | Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] |",
"| | TableScan: t2 projection=[a] |",
"| | TableScan: t1 projection=[a, b] |",
"+--------------+-------------------------------------------------------+",
];
assert_batches_eq!(expected, &df_results);
Ok(())
}

#[tokio::test]
async fn count_wildcard() -> Result<()> {
async fn test_count_wildcard_on_window() -> Result<()> {
let ctx = SessionContext::new();
let testdata = datafusion::test_util::parquet_test_data();

ctx.register_parquet(
"alltypes_tiny_pages",
&format!("{testdata}/alltypes_tiny_pages.parquet"),
ParquetReadOptions::default(),
)
.await?;
register_alltypes_tiny_pages_parquet(&ctx).await?;

let sql_results = ctx
.sql("select count(*) from alltypes_tiny_pages")
.sql("select COUNT(*) OVER(ORDER BY timestamp_col DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from alltypes_tiny_pages")
.await?
.select(vec![count(Expr::Wildcard)])?
.explain(false, false)?
.collect()
.await?;

// add `.select(vec![count(Expr::Wildcard)])?` to make sure we can analyze all nodes, not just the top node.
let df_results = ctx
.table("alltypes_tiny_pages")
.await?
.aggregate(vec![], vec![count(Expr::Wildcard)])?
.select(vec![count(Expr::Wildcard)])?
.select(vec![Expr::WindowFunction(expr::WindowFunction::new(
WindowFunction::AggregateFunction(AggregateFunction::Count),
vec![Expr::Wildcard],
vec![],
vec![Expr::Sort(Sort::new(
Box::new(col("timestamp_col")),
false,
true,
))],
WindowFrame {
units: WindowFrameUnits::Range,
start_bound: WindowFrameBound::Preceding(ScalarValue::IntervalDayTime(
Some(6),
)),
end_bound: WindowFrameBound::Following(ScalarValue::IntervalDayTime(
Some(2),
)),
},
))])?
.explain(false, false)?
.collect()
.await?;
@@ -72,21 +129,37 @@ async fn count_wildcard() -> Result<()> {
pretty_format_batches(&df_results)?.to_string()
);

let results = ctx
Ok(())
}

#[tokio::test]
async fn test_count_wildcard_on_aggregate() -> Result<()> {
let ctx = SessionContext::new();
register_alltypes_tiny_pages_parquet(&ctx).await?;

let sql_results = ctx
.sql("select count(*) from alltypes_tiny_pages")
.await?
.select(vec![count(Expr::Wildcard)])?
.explain(false, false)?
.collect()
.await?;

// add `.select(vec![count(Expr::Wildcard)])?` to make sure we can analyze all nodes, not just the top node.
let df_results = ctx
.table("alltypes_tiny_pages")
.await?
.aggregate(vec![], vec![count(Expr::Wildcard)])?
.select(vec![count(Expr::Wildcard)])?
.explain(false, false)?
.collect()
.await?;

let expected = vec![
"+-----------------+",
"| COUNT(UInt8(1)) |",
"+-----------------+",
"| 7300 |",
"+-----------------+",
];
assert_batches_sorted_eq!(expected, &results);
// make sure the SQL plan is the same as the DataFrame plan
assert_eq!(
pretty_format_batches(&sql_results)?.to_string(),
pretty_format_batches(&df_results)?.to_string()
);

Ok(())
}
@@ -1047,3 +1120,14 @@ async fn table_with_nested_types(n: usize) -> Result<DataFrame> {
ctx.register_batch("shapes", batch)?;
ctx.table("shapes").await
}

pub async fn register_alltypes_tiny_pages_parquet(ctx: &SessionContext) -> Result<()> {
let testdata = parquet_test_data();
ctx.register_parquet(
"alltypes_tiny_pages",
&format!("{testdata}/alltypes_tiny_pages.parquet"),
ParquetReadOptions::default(),
)
.await?;
Ok(())
}
9 changes: 0 additions & 9 deletions datafusion/core/tests/sql/mod.rs
@@ -1161,15 +1161,6 @@ async fn try_execute_to_batches(
/// Execute query and return results as a Vec of RecordBatches
async fn execute_to_batches(ctx: &SessionContext, sql: &str) -> Vec<RecordBatch> {
let df = ctx.sql(sql).await.unwrap();

// We are not really interested in the direct output of optimized_logical_plan
// since the physical plan construction already optimizes the given logical plan
// and we want to avoid double-optimization as a consequence. So we just construct
// it here to make sure that it doesn't fail at this step and get the optimized
// schema (to assert later that the logical and optimized schemas are the same).
let optimized = df.clone().into_optimized_plan().unwrap();
assert_eq!(df.logical_plan().schema(), optimized.schema());

df.collect().await.unwrap()
}

