Skip to content

Commit

Permalink
Merge remote-tracking branch 'apache/main' into string-view
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jul 12, 2024
2 parents 2f0a7ec + 5ba634a commit 921afdf
Show file tree
Hide file tree
Showing 52 changed files with 2,635 additions and 496 deletions.
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions datafusion-examples/examples/parse_sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,5 +153,14 @@ async fn round_trip_parse_sql_expr_demo() -> Result<()> {

assert_eq!(sql, round_trip_sql);

// enable pretty-unparsing. This make the output more human-readable
// but can be problematic when passed to other SQL engines due to
// difference in precedence rules between DataFusion and target engines.
let unparser = Unparser::default().with_pretty(true);

let pretty = "int_col < 5 OR double_col = 8";
let pretty_round_trip_sql = unparser.expr_to_sql(&parsed_expr)?.to_string();
assert_eq!(pretty, pretty_round_trip_sql);

Ok(())
}
18 changes: 15 additions & 3 deletions datafusion-examples/examples/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ use datafusion_sql::unparser::{plan_to_sql, Unparser};
/// 1. [`simple_expr_to_sql_demo`]: Create a simple expression [`Exprs`] with
/// fluent API and convert to sql suitable for passing to another database
///
/// 2. [`simple_expr_to_sql_demo_no_escape`] Create a simple expression
/// [`Exprs`] with fluent API and convert to sql without escaping column names
/// more suitable for displaying to humans.
/// 2. [`simple_expr_to_pretty_sql_demo`] Create a simple expression
/// [`Exprs`] with fluent API and convert to sql without extra parentheses,
/// suitable for displaying to humans
///
/// 3. [`simple_expr_to_sql_demo_escape_mysql_style`]" Create a simple
/// expression [`Exprs`] with fluent API and convert to sql escaping column
Expand All @@ -49,6 +49,7 @@ use datafusion_sql::unparser::{plan_to_sql, Unparser};
async fn main() -> Result<()> {
// See how to evaluate expressions
simple_expr_to_sql_demo()?;
simple_expr_to_pretty_sql_demo()?;
simple_expr_to_sql_demo_escape_mysql_style()?;
simple_plan_to_sql_demo().await?;
round_trip_plan_to_sql_demo().await?;
Expand All @@ -64,6 +65,17 @@ fn simple_expr_to_sql_demo() -> Result<()> {
Ok(())
}

/// DataFusioon can remove parentheses when converting an expression to SQL.
/// Note that output is intended for humans, not for other SQL engines,
/// as difference in precedence rules can cause expressions to be parsed differently.
fn simple_expr_to_pretty_sql_demo() -> Result<()> {
let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8)));
let unparser = Unparser::default().with_pretty(true);
let sql = unparser.expr_to_sql(&expr)?.to_string();
assert_eq!(sql, r#"a < 5 OR a = 8"#);
Ok(())
}

/// DataFusion can convert expressions to SQL without escaping column names using
/// using a custom dialect and an explicit unparser
fn simple_expr_to_sql_demo_escape_mysql_style() -> Result<()> {
Expand Down
10 changes: 10 additions & 0 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1984,6 +1984,16 @@ impl ScalarValue {
Self::new_list(values, data_type, true)
}

/// Create ListArray with Null with specific data type
///
/// - new_null_list(i32, nullable, 1): `ListArray[NULL]`
pub fn new_null_list(data_type: DataType, nullable: bool, null_len: usize) -> Self {
let data_type = DataType::List(Field::new_list_field(data_type, nullable).into());
Self::List(Arc::new(ListArray::from(ArrayData::new_null(
&data_type, null_len,
))))
}

/// Converts `IntoIterator<Item = ScalarValue>` where each element has type corresponding to
/// `data_type`, to a [`ListArray`].
///
Expand Down
4 changes: 1 addition & 3 deletions datafusion/core/src/execution/session_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -554,9 +554,7 @@ impl SessionState {
);
}
let statement = statements.pop_front().ok_or_else(|| {
DataFusionError::NotImplemented(
"No SQL statements were provided in the query string".to_string(),
)
plan_datafusion_err!("No SQL statements were provided in the query string")
})?;
Ok(statement)
}
Expand Down
6 changes: 6 additions & 0 deletions datafusion/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,12 @@ doc_comment::doctest!(
user_guide_example_usage
);

#[cfg(doctest)]
doc_comment::doctest!(
"../../../docs/source/user-guide/crate-configuration.md",
user_guide_crate_configuration
);

#[cfg(doctest)]
doc_comment::doctest!(
"../../../docs/source/user-guide/configs.md",
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2312,7 +2312,7 @@ mod tests {
// verify that the plan correctly casts u8 to i64
// the cast from u8 to i64 for literal will be simplified, and get lit(int64(5))
// the cast here is implicit so has CastOptions with safe=true
let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) } }";
let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) }, fail_on_overflow: false }";
assert!(format!("{exec_plan:?}").contains(expected));
Ok(())
}
Expand Down Expand Up @@ -2551,7 +2551,7 @@ mod tests {
let execution_plan = plan(&logical_plan).await?;
// verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated.

let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") } }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") } } }";
let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") }, fail_on_overflow: false }, fail_on_overflow: false }";

let actual = format!("{execution_plan:?}");
assert!(actual.contains(expected), "{}", actual);
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1388,7 +1388,7 @@ async fn unnest_with_redundant_columns() -> Result<()> {
let expected = vec![
"Projection: shapes.shape_id [shape_id:UInt32]",
" Unnest: lists[shape_id2] structs[] [shape_id:UInt32, shape_id2:UInt32;N]",
" Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} })]",
" Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} });N]",
" TableScan: shapes projection=[shape_id] [shape_id:UInt32]",
];

Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> {
Schema::new(vec![Field::new_list(
"ARRAY_AGG(DISTINCT aggregate_test_100.c2)",
Field::new("item", DataType::UInt32, false),
false
true
),])
);

Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/sql_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ async fn empty_statement_returns_error() {
let plan_res = state.create_logical_plan("").await;
assert_eq!(
plan_res.unwrap_err().strip_backtrace(),
"This feature is not implemented: No SQL statements were provided in the query string"
"Error during planning: No SQL statements were provided in the query string"
);
}

Expand Down
79 changes: 74 additions & 5 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,19 @@
//! This module contains end to end demonstrations of creating
//! user defined aggregate functions
use arrow::{array::AsArray, datatypes::Fields};
use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray};
use arrow_schema::Schema;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};

use arrow::{array::AsArray, datatypes::Fields};
use arrow_array::{
types::UInt64Type, Int32Array, PrimitiveArray, StringArray, StructArray,
};
use arrow_schema::Schema;

use datafusion::dataframe::DataFrame;
use datafusion::datasource::MemTable;
use datafusion::test_util::plan_and_collect;
use datafusion::{
Expand All @@ -45,8 +50,8 @@ use datafusion::{
};
use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err};
use datafusion_expr::{
create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
SimpleAggregateUDF,
col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
LogicalPlanBuilder, SimpleAggregateUDF,
};
use datafusion_functions_aggregate::average::AvgAccumulator;

Expand Down Expand Up @@ -377,6 +382,55 @@ async fn test_groups_accumulator() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_parameterized_aggregate_udf() -> Result<()> {
let batch = RecordBatch::try_from_iter([(
"text",
Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
)])?;

let ctx = SessionContext::new();
ctx.register_batch("t", batch)?;
let t = ctx.table("t").await?;
let signature = Signature::exact(vec![DataType::Utf8], Volatility::Immutable);
let udf1 = AggregateUDF::from(TestGroupsAccumulator {
signature: signature.clone(),
result: 1,
});
let udf2 = AggregateUDF::from(TestGroupsAccumulator {
signature: signature.clone(),
result: 2,
});

let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?)
.aggregate(
[col("text")],
[
udf1.call(vec![col("text")]).alias("a"),
udf2.call(vec![col("text")]).alias("b"),
],
)?
.build()?;

assert_eq!(
format!("{plan:?}"),
"Aggregate: groupBy=[[t.text]], aggr=[[geo_mean(t.text) AS a, geo_mean(t.text) AS b]]\n TableScan: t projection=[text]"
);

let actual = DataFrame::new(ctx.state(), plan).collect().await?;
let expected = [
"+------+---+---+",
"| text | a | b |",
"+------+---+---+",
"| foo | 1 | 2 |",
"+------+---+---+",
];
assert_batches_eq!(expected, &actual);

ctx.deregister_table("t")?;
Ok(())
}

/// Returns an context with a table "t" and the "first" and "time_sum"
/// aggregate functions registered.
///
Expand Down Expand Up @@ -735,6 +789,21 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
) -> Result<Box<dyn GroupsAccumulator>> {
Ok(Box::new(self.clone()))
}

fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
if let Some(other) = other.as_any().downcast_ref::<TestGroupsAccumulator>() {
self.result == other.result && self.signature == other.signature
} else {
false
}
}

fn hash_value(&self) -> u64 {
let hasher = &mut DefaultHasher::new();
self.signature.hash(hasher);
self.result.hash(hasher);
hasher.finish()
}
}

impl Accumulator for TestGroupsAccumulator {
Expand Down
Loading

0 comments on commit 921afdf

Please sign in to comment.