Skip to content

Commit

Permalink
Introduce FunctionRegistry dependency to optimize and rewrite rule (a…
Browse files Browse the repository at this point in the history
…pache#10714)

* mv function registry to expr

Signed-off-by: jayzhan211 <[email protected]>

* registry move to config trait

Signed-off-by: jayzhan211 <[email protected]>

* fix test

Signed-off-by: jayzhan211 <[email protected]>

* fix test

Signed-off-by: jayzhan211 <[email protected]>

* rm dependency

Signed-off-by: jayzhan211 <[email protected]>

* fix cli cargo lock

Signed-off-by: jayzhan211 <[email protected]>

---------

Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 authored and findepi committed Jul 16, 2024
1 parent 215fb4b commit 90aca5c
Show file tree
Hide file tree
Showing 9 changed files with 147 additions and 140 deletions.
160 changes: 79 additions & 81 deletions datafusion-cli/Cargo.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2350,6 +2350,10 @@ impl OptimizerConfig for SessionState {
fn options(&self) -> &ConfigOptions {
self.config_options()
}

fn function_registry(&self) -> Option<&dyn FunctionRegistry> {
Some(self)
}
}

/// Create a new task context instance from SessionContext
Expand Down
7 changes: 6 additions & 1 deletion datafusion/execution/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,16 @@ pub mod config;
pub mod disk_manager;
pub mod memory_pool;
pub mod object_store;
pub mod registry;
pub mod runtime_env;
mod stream;
mod task;

pub mod registry {
pub use datafusion_expr::registry::{
FunctionRegistry, MemoryFunctionRegistry, SerializerRegistry,
};
}

pub use disk_manager::DiskManager;
pub use registry::FunctionRegistry;
pub use stream::{RecordBatchStream, SendableRecordBatchStream};
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pub mod function;
pub mod groups_accumulator;
pub mod interval_arithmetic;
pub mod logical_plan;
pub mod registry;
pub mod simplify;
pub mod sort_properties;
pub mod tree_node;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

//! FunctionRegistry trait
use crate::expr_rewriter::FunctionRewrite;
use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF};
use datafusion_common::{not_impl_err, plan_datafusion_err, Result};
use datafusion_expr::expr_rewriter::FunctionRewrite;
use datafusion_expr::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF};
use std::collections::HashMap;
use std::{collections::HashSet, sync::Arc};

Expand Down
1 change: 0 additions & 1 deletion datafusion/optimizer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ async-trait = { workspace = true }
chrono = { workspace = true }
datafusion-common = { workspace = true, default-features = true }
datafusion-expr = { workspace = true }
datafusion-functions-aggregate = { workspace = true }
datafusion-physical-expr = { workspace = true }
hashbrown = { workspace = true }
indexmap = { workspace = true }
Expand Down
5 changes: 5 additions & 0 deletions datafusion/optimizer/src/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use std::collections::HashSet;
use std::sync::Arc;

use chrono::{DateTime, Utc};
use datafusion_expr::registry::FunctionRegistry;
use log::{debug, warn};

use datafusion_common::alias::AliasGenerator;
Expand Down Expand Up @@ -122,6 +123,10 @@ pub trait OptimizerConfig {
fn alias_generator(&self) -> Arc<AliasGenerator>;

fn options(&self) -> &ConfigOptions;

fn function_registry(&self) -> Option<&dyn FunctionRegistry> {
None
}
}

/// A standalone [`OptimizerConfig`] that can be used independently
Expand Down
69 changes: 14 additions & 55 deletions datafusion/optimizer/src/replace_distinct_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ use crate::{OptimizerConfig, OptimizerRule};

use datafusion_common::tree_node::Transformed;
use datafusion_common::{internal_err, Column, Result};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::expr_rewriter::normalize_cols;
use datafusion_expr::utils::expand_wildcard;
use datafusion_expr::{col, LogicalPlanBuilder};
use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan};
use datafusion_functions_aggregate::first_last::first_value;

/// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]]
///
Expand Down Expand Up @@ -73,7 +73,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
fn rewrite(
&self,
plan: LogicalPlan,
_config: &dyn OptimizerConfig,
config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
match plan {
LogicalPlan::Distinct(Distinct::All(input)) => {
Expand All @@ -95,9 +95,18 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
let expr_cnt = on_expr.len();

// Construct the aggregation expression to be used to fetch the selected expressions.
let aggr_expr = select_expr
.into_iter()
.map(|e| first_value(vec![e], false, None, sort_expr.clone(), None));
let first_value_udaf =
config.function_registry().unwrap().udaf("first_value")?;
let aggr_expr = select_expr.into_iter().map(|e| {
Expr::AggregateFunction(AggregateFunction::new_udf(
first_value_udaf.clone(),
vec![e],
false,
None,
sort_expr.clone(),
None,
))
});

let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?;
let group_expr = normalize_cols(on_expr, input.as_ref())?;
Expand Down Expand Up @@ -163,53 +172,3 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
Some(BottomUp)
}
}

#[cfg(test)]
mod tests {
use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate;
use crate::test::{assert_optimized_plan_eq, test_table_scan};
use datafusion_expr::{col, LogicalPlanBuilder};
use std::sync::Arc;

#[test]
fn replace_distinct() -> datafusion_common::Result<()> {
let table_scan = test_table_scan().unwrap();
let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b")])?
.distinct()?
.build()?;

let expected = "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\
\n Projection: test.a, test.b\
\n TableScan: test";

assert_optimized_plan_eq(
Arc::new(ReplaceDistinctWithAggregate::new()),
plan,
expected,
)
}

#[test]
fn replace_distinct_on() -> datafusion_common::Result<()> {
let table_scan = test_table_scan().unwrap();
let plan = LogicalPlanBuilder::from(table_scan)
.distinct_on(
vec![col("a")],
vec![col("b")],
Some(vec![col("a").sort(false, true), col("c").sort(true, false)]),
)?
.build()?;

let expected = "Projection: first_value(test.b) ORDER BY [test.a DESC NULLS FIRST, test.c ASC NULLS LAST] AS b\
\n Sort: test.a DESC NULLS FIRST\
\n Aggregate: groupBy=[[test.a]], aggr=[[first_value(test.b) ORDER BY [test.a DESC NULLS FIRST, test.c ASC NULLS LAST]]]\
\n TableScan: test";

assert_optimized_plan_eq(
Arc::new(ReplaceDistinctWithAggregate::new()),
plan,
expected,
)
}
}
36 changes: 36 additions & 0 deletions datafusion/sqllogictest/test_files/distinct_on.slt
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,39 @@ LIMIT 3;
-25 15295
45 15673
-72 -11122

# test distinct on
statement ok
create table t(a int, b int, c int) as values (1, 2, 3);

statement ok
set datafusion.explain.logical_plan_only = true;

query TT
explain select distinct on (a) b from t order by a desc, c;
----
logical_plan
01)Projection: first_value(t.b) ORDER BY [t.a DESC NULLS FIRST, t.c ASC NULLS LAST] AS b
02)--Sort: t.a DESC NULLS FIRST
03)----Aggregate: groupBy=[[t.a]], aggr=[[first_value(t.b) ORDER BY [t.a DESC NULLS FIRST, t.c ASC NULLS LAST]]]
04)------TableScan: t projection=[a, b, c]

statement ok
drop table t;

# test distinct
statement ok
create table t(a int, b int) as values (1, 2);

statement ok
set datafusion.explain.logical_plan_only = true;

query TT
explain select distinct a, b from t;
----
logical_plan
01)Aggregate: groupBy=[[t.a, t.b]], aggr=[[]]
02)--TableScan: t projection=[a, b]

statement ok
drop table t;

0 comments on commit 90aca5c

Please sign in to comment.