
Commit 984c6ee

Merge branch 'main' into physicalexpr-cse
# Conflicts:
#	datafusion/common/src/cse.rs
peter-toth committed Oct 24, 2024
2 parents d2529ce + 8adbc23
Showing 123 changed files with 3,204 additions and 2,394 deletions.
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Diff not rendered (generated file).

51 changes: 34 additions & 17 deletions datafusion/common/src/cse.rs
@@ -120,10 +120,19 @@ impl<'n, N: HashNode> Identifier<'n, N> {
 /// ```
 type IdArray<'n, N> = Vec<(usize, Option<Identifier<'n, N>>)>;
 
-/// A map that contains the number of normal and conditional occurrences of [`TreeNode`]s
-/// by their identifiers. It also contains the position of a [`TreeNode`] in
-/// [`CommonNodes`] once a node is found to be common and got extracted.
-type NodeStats<'n, N> = HashMap<Identifier<'n, N>, (usize, usize, Option<usize>)>;
+/// How many times a node is evaluated. A node is considered common if it is surely
+/// evaluated at least twice, or surely evaluated once and also conditionally.
+#[derive(PartialEq, Eq)]
+enum NodeEvaluation {
+    SurelyOnce,
+    ConditionallyAtLeastOnce,
+    Common,
+}
+
+/// A map that contains the evaluation stats of [`TreeNode`]s by their identifiers.
+/// It also contains the position of [`TreeNode`]s in [`CommonNodes`] once a node is
+/// found to be common and gets extracted.
+type NodeStats<'n, N> = HashMap<Identifier<'n, N>, (NodeEvaluation, Option<usize>)>;
 
 /// A list that contains the common [`TreeNode`]s and their alias, extracted during the
 /// second, rewriting traversal.
@@ -331,16 +340,25 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisitor
         self.id_array[down_index].0 = self.up_index;
         if is_valid && !self.controller.is_ignored(node) {
             self.id_array[down_index].1 = Some(node_id);
-            let (count, conditional_count, _) =
-                self.node_stats.entry(node_id).or_insert((0, 0, None));
-            if self.conditional {
-                *conditional_count += 1;
-            } else {
-                *count += 1;
-            }
-            if *count > 1 || (*count == 1 && *conditional_count > 0) {
-                self.found_common = true;
-            }
+            self.node_stats
+                .entry(node_id)
+                .and_modify(|(evaluation, _)| {
+                    if *evaluation == NodeEvaluation::SurelyOnce
+                        || *evaluation == NodeEvaluation::ConditionallyAtLeastOnce
+                            && !self.conditional
+                    {
+                        *evaluation = NodeEvaluation::Common;
+                        self.found_common = true;
+                    }
+                })
+                .or_insert_with(|| {
+                    let evaluation = if self.conditional {
+                        NodeEvaluation::ConditionallyAtLeastOnce
+                    } else {
+                        NodeEvaluation::SurelyOnce
+                    };
+                    (evaluation, None)
+                });
         }
         self.visit_stack
             .push(VisitRecord::NodeItem(node_id, is_valid));
@@ -383,9 +401,8 @@ impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter
 
         // Handle nodes with identifiers only
         if let Some(node_id) = node_id {
-            let (count, conditional_count, common_index) =
-                self.node_stats.get_mut(&node_id).unwrap();
-            if *count > 1 || *count == 1 && *conditional_count > 0 {
+            let (evaluation, common_index) = self.node_stats.get_mut(&node_id).unwrap();
+            if *evaluation == NodeEvaluation::Common {
                 // step index to skip all sub-nodes (which have smaller series numbers).
                 while self.down_index < self.id_array.len()
                     && self.id_array[self.down_index].0 < up_index
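For readers following the CSE change: the two tuple counters are replaced by a three-state machine, where a node becomes `Common` on its second sighting unless both sightings are conditional. A minimal standalone sketch of those transitions (not DataFusion code; the `record_evaluation` helper and string keys are invented for illustration):

```rust
use std::collections::HashMap;

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum NodeEvaluation {
    SurelyOnce,
    ConditionallyAtLeastOnce,
    Common,
}

/// Record one sighting of the node keyed by `id`; return whether the node
/// has just become common (mirrors `found_common` in the visitor).
fn record_evaluation(
    stats: &mut HashMap<&'static str, NodeEvaluation>,
    id: &'static str,
    conditional: bool,
) -> bool {
    use NodeEvaluation::*;
    match stats.get(id).copied() {
        // A surely-evaluated node becomes common on any second sighting.
        Some(SurelyOnce) => {
            stats.insert(id, Common);
            true
        }
        // A conditionally-evaluated node becomes common only once it is
        // also seen on a surely-evaluated path.
        Some(ConditionallyAtLeastOnce) if !conditional => {
            stats.insert(id, Common);
            true
        }
        Some(ConditionallyAtLeastOnce) | Some(Common) => false,
        // First sighting: remember how the node was reached.
        None => {
            stats.insert(
                id,
                if conditional { ConditionallyAtLeastOnce } else { SurelyOnce },
            );
            false
        }
    }
}

fn main() {
    let mut stats = HashMap::new();
    // Seen once in a conditional branch, then once unconditionally -> common.
    assert!(!record_evaluation(&mut stats, "a + 1", true));
    assert!(record_evaluation(&mut stats, "a + 1", false));
    // Seen only in conditional branches -> never common.
    assert!(!record_evaluation(&mut stats, "b * 2", true));
    assert!(!record_evaluation(&mut stats, "b * 2", true));
}
```

Seen this way, `ConditionallyAtLeastOnce` can only be promoted by a sure evaluation: extracting a subexpression that may never run at all would not pay for itself.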
1 change: 0 additions & 1 deletion datafusion/common/src/dfschema.rs
@@ -315,7 +315,6 @@ impl DFSchema {
             None => self_unqualified_names.contains(field.name().as_str()),
         };
         if !duplicated_field {
-            // self.inner.fields.push(field.clone());
             schema_builder.push(Arc::clone(field));
             qualifiers.push(qualifier.cloned());
         }
3 changes: 1 addition & 2 deletions datafusion/common/src/hash_utils.rs
@@ -102,8 +102,7 @@ fn hash_array_primitive<T>(
     hashes_buffer: &mut [u64],
     rehash: bool,
 ) where
-    T: ArrowPrimitiveType,
-    <T as arrow_array::ArrowPrimitiveType>::Native: HashValue,
+    T: ArrowPrimitiveType<Native: HashValue>,
 {
     assert_eq!(
         hashes_buffer.len(),
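The rewritten bound uses Rust's associated type bounds syntax (stable since Rust 1.79): `T: ArrowPrimitiveType<Native: HashValue>` folds the separate where-clause on `Native` into the type parameter list. A minimal sketch of the same pattern, with `join_all` and `Display` as invented stand-ins for the arrow traits:

```rust
use std::fmt::Display;

// `I: Iterator<Item: Display>` is shorthand for the two-clause form
// `where I: Iterator, I::Item: Display`, just like the bound in the hunk above.
fn join_all<I: Iterator<Item: Display>>(items: I) -> String {
    items
        .map(|item| item.to_string())
        .collect::<Vec<_>>()
        .join(", ")
}

fn main() {
    assert_eq!(join_all([1, 2, 3].into_iter()), "1, 2, 3");
}
```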
24 changes: 4 additions & 20 deletions datafusion/common/src/utils/mod.rs
@@ -23,16 +23,14 @@ pub mod proxy;
 pub mod string_utils;
 
 use crate::error::{_internal_datafusion_err, _internal_err};
-use crate::{arrow_datafusion_err, DataFusionError, Result, ScalarValue};
-use arrow::array::{ArrayRef, PrimitiveArray};
+use crate::{DataFusionError, Result, ScalarValue};
+use arrow::array::ArrayRef;
 use arrow::buffer::OffsetBuffer;
-use arrow::compute::{partition, take_arrays, SortColumn, SortOptions};
-use arrow::datatypes::{Field, SchemaRef, UInt32Type};
-use arrow::record_batch::RecordBatch;
+use arrow::compute::{partition, SortColumn, SortOptions};
+use arrow::datatypes::{Field, SchemaRef};
 use arrow_array::cast::AsArray;
 use arrow_array::{
     Array, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait,
-    RecordBatchOptions,
 };
 use arrow_schema::DataType;
 use sqlparser::ast::Ident;
@@ -92,20 +90,6 @@ pub fn get_row_at_idx(columns: &[ArrayRef], idx: usize) -> Result<Vec<ScalarValue>>
         .collect()
 }
 
-/// Construct a new RecordBatch from the rows of the `record_batch` at the `indices`.
-pub fn get_record_batch_at_indices(
-    record_batch: &RecordBatch,
-    indices: &PrimitiveArray<UInt32Type>,
-) -> Result<RecordBatch> {
-    let new_columns = take_arrays(record_batch.columns(), indices, None)?;
-    RecordBatch::try_new_with_options(
-        record_batch.schema(),
-        new_columns,
-        &RecordBatchOptions::new().with_row_count(Some(indices.len())),
-    )
-    .map_err(|e| arrow_datafusion_err!(e))
-}
-
 /// This function compares two tuples depending on the given sort options.
 pub fn compare_rows(
     x: &[ScalarValue],
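Downstream code that used the removed `get_record_batch_at_indices` can inline its body, which the deleted hunk above preserves. A hedged sketch of an equivalent standalone helper, returning arrow's `Result` instead of DataFusion's (the `record_batch_at_indices` name is invented):

```rust
use arrow::array::{Array, PrimitiveArray};
use arrow::compute::take_arrays;
use arrow::datatypes::UInt32Type;
use arrow::error::Result;
use arrow::record_batch::{RecordBatch, RecordBatchOptions};

/// Select the rows of `record_batch` at `indices`, mirroring the deleted
/// helper; the row count is set explicitly so zero-column batches keep
/// the right length.
fn record_batch_at_indices(
    record_batch: &RecordBatch,
    indices: &PrimitiveArray<UInt32Type>,
) -> Result<RecordBatch> {
    let new_columns = take_arrays(record_batch.columns(), indices, None)?;
    RecordBatch::try_new_with_options(
        record_batch.schema(),
        new_columns,
        &RecordBatchOptions::new().with_row_count(Some(indices.len())),
    )
}
```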
6 changes: 2 additions & 4 deletions datafusion/core/benches/sql_planner.rs
@@ -203,10 +203,8 @@ fn criterion_benchmark(c: &mut Criterion) {
 
     let tpcds_ctx = register_defs(SessionContext::new(), tpcds_schemas());
 
-    // 10, 35: Physical plan does not support logical expression Exists(<subquery>)
-    // 45: Physical plan does not support logical expression (<subquery>)
-    // 41: Optimizing disjunctions not supported
-    let ignored = [10, 35, 41, 45];
+    // 41: check_analyzed_plan: Correlated column is not allowed in predicate
+    let ignored = [41];
 
     let raw_tpcds_sql_queries = (1..100)
         .filter(|q| !ignored.contains(q))
4 changes: 0 additions & 4 deletions datafusion/core/src/catalog_common/mod.rs
@@ -36,10 +36,6 @@ pub use datafusion_sql::{ResolvedTableReference, TableReference};
 use std::collections::BTreeSet;
 use std::ops::ControlFlow;
 
-/// See [`CatalogProviderList`]
-#[deprecated(since = "35.0.0", note = "use [`CatalogProviderList`] instead")]
-pub trait CatalogList: CatalogProviderList {}
-
 /// Collects all tables and views referenced in the SQL statement. CTEs are collected separately.
 /// This can be used to determine which tables need to be in the catalog for a query to be planned.
 ///
25 changes: 1 addition & 24 deletions datafusion/core/src/dataframe/mod.rs
@@ -373,32 +373,9 @@ impl DataFrame {
         self.select(expr)
     }
 
-    /// Expand each list element of a column to multiple rows.
-    #[deprecated(since = "37.0.0", note = "use unnest_columns instead")]
-    pub fn unnest_column(self, column: &str) -> Result<DataFrame> {
-        self.unnest_columns(&[column])
-    }
-
-    /// Expand each list element of a column to multiple rows, with
-    /// behavior controlled by [`UnnestOptions`].
-    ///
-    /// Please see the documentation on [`UnnestOptions`] for more
-    /// details about the meaning of unnest.
-    #[deprecated(since = "37.0.0", note = "use unnest_columns_with_options instead")]
-    pub fn unnest_column_with_options(
-        self,
-        column: &str,
-        options: UnnestOptions,
-    ) -> Result<DataFrame> {
-        self.unnest_columns_with_options(&[column], options)
-    }
-
     /// Expand multiple list/struct columns into a set of rows and new columns.
     ///
-    /// See also:
-    ///
-    /// 1. [`UnnestOptions`] documentation for the behavior of `unnest`
-    /// 2. [`Self::unnest_column_with_options`]
+    /// See also: [`UnnestOptions`] documentation for the behavior of `unnest`
     ///
     /// # Example
     /// ```
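Migration for the removed deprecated methods is mechanical, since their deleted bodies were one-line wrappers over the multi-column APIs. A hedged sketch (the `unnest_one*` helper names are invented):

```rust
use datafusion::dataframe::DataFrame;
use datafusion::error::Result;
use datafusion_common::UnnestOptions;

/// Stand-in for the removed `DataFrame::unnest_column`: delegate to the
/// multi-column API with a one-element slice, exactly as the deleted body did.
fn unnest_one(df: DataFrame, column: &str) -> Result<DataFrame> {
    df.unnest_columns(&[column])
}

/// Stand-in for the removed `DataFrame::unnest_column_with_options`.
fn unnest_one_with_options(
    df: DataFrame,
    column: &str,
    options: UnnestOptions,
) -> Result<DataFrame> {
    df.unnest_columns_with_options(&[column], options)
}
```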
25 changes: 18 additions & 7 deletions datafusion/core/src/physical_planner.rs
@@ -29,13 +29,12 @@ use crate::error::{DataFusionError, Result};
 use crate::execution::context::{ExecutionProps, SessionState};
 use crate::logical_expr::utils::generate_sort_key;
 use crate::logical_expr::{
-    Aggregate, EmptyRelation, Join, Projection, Sort, TableScan, Unnest, Window,
+    Aggregate, EmptyRelation, Join, Projection, Sort, TableScan, Unnest, Values, Window,
 };
 use crate::logical_expr::{
     Expr, LogicalPlan, Partitioning as LogicalPartitioning, PlanType, Repartition,
     UserDefinedLogicalNode,
 };
-use crate::logical_expr::{Limit, Values};
 use crate::physical_expr::{create_physical_expr, create_physical_exprs};
 use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
 use crate::physical_plan::analyze::AnalyzeExec;
@@ -78,8 +77,8 @@ use datafusion_expr::expr::{
 use datafusion_expr::expr_rewriter::unnormalize_cols;
 use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
 use datafusion_expr::{
-    DescribeTable, DmlStatement, Extension, Filter, JoinType, RecursiveQuery, SortExpr,
-    StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp,
+    DescribeTable, DmlStatement, Extension, FetchType, Filter, JoinType, RecursiveQuery,
+    SkipType, SortExpr, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp,
 };
 use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
 use datafusion_physical_expr::expressions::Literal;
@@ -796,8 +795,20 @@ impl DefaultPhysicalPlanner {
             }
             LogicalPlan::Subquery(_) => todo!(),
             LogicalPlan::SubqueryAlias(_) => children.one()?,
-            LogicalPlan::Limit(Limit { skip, fetch, .. }) => {
+            LogicalPlan::Limit(limit) => {
                 let input = children.one()?;
+                let SkipType::Literal(skip) = limit.get_skip_type()? else {
+                    return not_impl_err!(
+                        "Unsupported OFFSET expression: {:?}",
+                        limit.skip
+                    );
+                };
+                let FetchType::Literal(fetch) = limit.get_fetch_type()? else {
+                    return not_impl_err!(
+                        "Unsupported LIMIT expression: {:?}",
+                        limit.fetch
+                    );
+                };
 
                 // GlobalLimitExec requires a single partition for input
                 let input = if input.output_partitioning().partition_count() == 1 {
@@ -806,13 +817,13 @@
                     // Apply a LocalLimitExec to each partition. The optimizer will also insert
                     // a CoalescePartitionsExec between the GlobalLimitExec and LocalLimitExec
                     if let Some(fetch) = fetch {
-                        Arc::new(LocalLimitExec::new(input, *fetch + skip))
+                        Arc::new(LocalLimitExec::new(input, fetch + skip))
                     } else {
                         input
                     }
                 };
 
-                Arc::new(GlobalLimitExec::new(input, *skip, *fetch))
+                Arc::new(GlobalLimitExec::new(input, skip, fetch))
             }
             LogicalPlan::Unnest(Unnest {
                 list_type_columns,
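One subtlety the new limit planning preserves: the per-partition limit is `fetch + skip`, because OFFSET is applied only after partitions are merged, so in the worst case every skipped row and every returned row comes from the same partition. A small sketch of the bound (the `local_limit_bound` helper is invented for illustration):

```rust
/// Upper bound on rows a partition must keep so that a global operator can
/// still skip `skip` rows and then return up to `fetch` rows.
fn local_limit_bound(skip: usize, fetch: Option<usize>) -> Option<usize> {
    // Worst case: all skipped and all returned rows live in one partition,
    // so a local limit below `fetch + skip` could lose results. With no
    // fetch there is no safe local truncation at all.
    fetch.map(|fetch| fetch + skip)
}

fn main() {
    // LIMIT 5 OFFSET 10: each partition keeps at most 15 rows locally.
    assert_eq!(local_limit_bound(10, Some(5)), Some(15));
    // No LIMIT: pass partitions through untruncated.
    assert_eq!(local_limit_bound(10, None), None);
}
```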
3 changes: 3 additions & 0 deletions datafusion/core/tests/core_integration.rs
@@ -24,6 +24,9 @@ mod dataframe;
 /// Run all tests that are found in the `macro_hygiene` directory
 mod macro_hygiene;
 
+/// Run all tests that are found in the `execution` directory
+mod execution;
+
 /// Run all tests that are found in the `expr_api` directory
 mod expr_api;
 
95 changes: 95 additions & 0 deletions datafusion/core/tests/execution/logical_plan.rs
@@ -0,0 +1,95 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_array::Int64Array;
+use arrow_schema::{DataType, Field};
+use datafusion::execution::session_state::SessionStateBuilder;
+use datafusion_common::{Column, DFSchema, Result, ScalarValue};
+use datafusion_execution::TaskContext;
+use datafusion_expr::expr::AggregateFunction;
+use datafusion_expr::logical_plan::{LogicalPlan, Values};
+use datafusion_expr::{Aggregate, AggregateUDF, Expr};
+use datafusion_functions_aggregate::count::Count;
+use datafusion_physical_plan::collect;
+use std::collections::HashMap;
+use std::fmt::Debug;
+use std::ops::Deref;
+use std::sync::Arc;
+
+// Logical plans need to provide stable semantics, as downstream projects
+// create them and depend on them. Test executable semantics of logical plans.
+#[tokio::test]
+async fn count_only_nulls() -> Result<()> {
+    // Input: VALUES (NULL), (NULL), (NULL) AS _(col)
+    let input_schema = Arc::new(DFSchema::from_unqualified_fields(
+        vec![Field::new("col", DataType::Null, true)].into(),
+        HashMap::new(),
+    )?);
+    let input = Arc::new(LogicalPlan::Values(Values {
+        schema: input_schema,
+        values: vec![
+            vec![Expr::Literal(ScalarValue::Null)],
+            vec![Expr::Literal(ScalarValue::Null)],
+            vec![Expr::Literal(ScalarValue::Null)],
+        ],
+    }));
+    let input_col_ref = Expr::Column(Column {
+        relation: None,
+        name: "col".to_string(),
+    });
+
+    // Aggregation: count(col) AS count
+    let aggregate = LogicalPlan::Aggregate(Aggregate::try_new(
+        input,
+        vec![],
+        vec![Expr::AggregateFunction(AggregateFunction {
+            func: Arc::new(AggregateUDF::new_from_impl(Count::new())),
+            args: vec![input_col_ref],
+            distinct: false,
+            filter: None,
+            order_by: None,
+            null_treatment: None,
+        })],
+    )?);
+
+    // Execute and verify results
+    let session_state = SessionStateBuilder::new().build();
+    let physical_plan = session_state.create_physical_plan(&aggregate).await?;
+    let result =
+        collect(physical_plan, Arc::new(TaskContext::from(&session_state))).await?;
+
+    let result = only(result.as_slice());
+    let result_schema = result.schema();
+    let field = only(result_schema.fields().deref());
+    let column = only(result.columns());
+
+    assert_eq!(field.data_type(), &DataType::Int64); // TODO should be UInt64
+    assert_eq!(column.deref(), &Int64Array::from(vec![0]));
+
+    Ok(())
+}
+
+fn only<T>(elements: &[T]) -> &T
+where
+    T: Debug,
+{
+    let [element] = elements else {
+        panic!("Expected exactly one element, got {:?}", elements);
+    };
+    element
+}
18 changes: 18 additions & 0 deletions datafusion/core/tests/execution/mod.rs
@@ -0,0 +1,18 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod logical_plan;
4 changes: 2 additions & 2 deletions datafusion/core/tests/expr_api/simplification.rs
@@ -333,8 +333,8 @@ fn simplify_scan_predicate() -> Result<()> {
         .build()?;
 
     // before simplify: t.g = power(t.f, 1.0)
-    // after simplify: (t.g = t.f) as "t.g = power(t.f, 1.0)"
-    let expected = "TableScan: test, full_filters=[g = f AS g = power(f,Float64(1))]";
+    // after simplify: t.g = t.f
+    let expected = "TableScan: test, full_filters=[g = f]";
     let actual = get_optimized_plan_formatted(plan, &Utc::now());
     assert_eq!(expected, actual);
     Ok(())
(Diffs for the remaining changed files are not shown.)
