Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Short-term way to make AggregateStatistics still work when min/max is converted to UDAF #11261

Merged
merged 2 commits into from
Jul 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 85 additions & 51 deletions datafusion/core/src/physical_optimizer/aggregate_statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,31 +140,29 @@ fn take_optimizable_column_and_table_count(
stats: &Statistics,
) -> Option<(ScalarValue, String)> {
let col_stats = &stats.column_statistics;
if let Some(agg_expr) = agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
if agg_expr.fun().name() == "count" && !agg_expr.is_distinct() {
if let Precision::Exact(num_rows) = stats.num_rows {
let exprs = agg_expr.expressions();
if exprs.len() == 1 {
// TODO optimize with exprs other than Column
if let Some(col_expr) =
exprs[0].as_any().downcast_ref::<expressions::Column>()
{
let current_val = &col_stats[col_expr.index()].null_count;
if let &Precision::Exact(val) = current_val {
return Some((
ScalarValue::Int64(Some((num_rows - val) as i64)),
agg_expr.name().to_string(),
));
}
} else if let Some(lit_expr) =
exprs[0].as_any().downcast_ref::<expressions::Literal>()
{
if lit_expr.value() == &COUNT_STAR_EXPANSION {
return Some((
ScalarValue::Int64(Some(num_rows as i64)),
agg_expr.name().to_string(),
));
}
if is_non_distinct_count(agg_expr) {
if let Precision::Exact(num_rows) = stats.num_rows {
let exprs = agg_expr.expressions();
if exprs.len() == 1 {
// TODO optimize with exprs other than Column
if let Some(col_expr) =
exprs[0].as_any().downcast_ref::<expressions::Column>()
{
let current_val = &col_stats[col_expr.index()].null_count;
if let &Precision::Exact(val) = current_val {
return Some((
ScalarValue::Int64(Some((num_rows - val) as i64)),
agg_expr.name().to_string(),
));
}
} else if let Some(lit_expr) =
exprs[0].as_any().downcast_ref::<expressions::Literal>()
{
if lit_expr.value() == &COUNT_STAR_EXPANSION {
return Some((
ScalarValue::Int64(Some(num_rows as i64)),
agg_expr.name().to_string(),
));
}
}
}
Expand All @@ -182,34 +180,30 @@ fn take_optimizable_min(
match *num_rows {
0 => {
// MIN/MAX with 0 rows is always null
if let Some(casted_expr) =
agg_expr.as_any().downcast_ref::<expressions::Min>()
{
if is_min(agg_expr) {
if let Ok(min_data_type) =
ScalarValue::try_from(casted_expr.field().unwrap().data_type())
ScalarValue::try_from(agg_expr.field().unwrap().data_type())
{
return Some((min_data_type, casted_expr.name().to_string()));
return Some((min_data_type, agg_expr.name().to_string()));
}
}
}
value if value > 0 => {
let col_stats = &stats.column_statistics;
if let Some(casted_expr) =
agg_expr.as_any().downcast_ref::<expressions::Min>()
{
if casted_expr.expressions().len() == 1 {
if is_min(agg_expr) {
let exprs = agg_expr.expressions();
if exprs.len() == 1 {
// TODO optimize with exprs other than Column
if let Some(col_expr) = casted_expr.expressions()[0]
.as_any()
.downcast_ref::<expressions::Column>()
if let Some(col_expr) =
exprs[0].as_any().downcast_ref::<expressions::Column>()
{
if let Precision::Exact(val) =
&col_stats[col_expr.index()].min_value
{
if !val.is_null() {
return Some((
val.clone(),
casted_expr.name().to_string(),
agg_expr.name().to_string(),
));
}
}
Expand All @@ -232,34 +226,30 @@ fn take_optimizable_max(
match *num_rows {
0 => {
// MIN/MAX with 0 rows is always null
if let Some(casted_expr) =
agg_expr.as_any().downcast_ref::<expressions::Max>()
{
if is_max(agg_expr) {
if let Ok(max_data_type) =
ScalarValue::try_from(casted_expr.field().unwrap().data_type())
ScalarValue::try_from(agg_expr.field().unwrap().data_type())
{
return Some((max_data_type, casted_expr.name().to_string()));
return Some((max_data_type, agg_expr.name().to_string()));
}
}
}
value if value > 0 => {
let col_stats = &stats.column_statistics;
if let Some(casted_expr) =
agg_expr.as_any().downcast_ref::<expressions::Max>()
{
if casted_expr.expressions().len() == 1 {
if is_max(agg_expr) {
let exprs = agg_expr.expressions();
if exprs.len() == 1 {
// TODO optimize with exprs other than Column
if let Some(col_expr) = casted_expr.expressions()[0]
.as_any()
.downcast_ref::<expressions::Column>()
if let Some(col_expr) =
exprs[0].as_any().downcast_ref::<expressions::Column>()
{
if let Precision::Exact(val) =
&col_stats[col_expr.index()].max_value
{
if !val.is_null() {
return Some((
val.clone(),
casted_expr.name().to_string(),
agg_expr.name().to_string(),
));
}
}
Expand All @@ -273,6 +263,50 @@ fn take_optimizable_max(
None
}

// TODO: Move this check into AggregateUDFImpl
// https://github.com/apache/datafusion/issues/11153
fn is_non_distinct_count(agg_expr: &dyn AggregateExpr) -> bool {
alamb marked this conversation as resolved.
Show resolved Hide resolved
if let Some(agg_expr) = agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
if agg_expr.fun().name() == "count" && !agg_expr.is_distinct() {
return true;
}
}

false
}

// TODO: Move this check into AggregateUDFImpl
// https://github.com/apache/datafusion/issues/11153
fn is_min(agg_expr: &dyn AggregateExpr) -> bool {
alamb marked this conversation as resolved.
Show resolved Hide resolved
if agg_expr.as_any().is::<expressions::Min>() {
return true;
}

if let Some(agg_expr) = agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
if agg_expr.fun().name() == "min" {
return true;
}
}

false
}

// TODO: Move this check into AggregateUDFImpl
// https://github.com/apache/datafusion/issues/11153
fn is_max(agg_expr: &dyn AggregateExpr) -> bool {
alamb marked this conversation as resolved.
Show resolved Hide resolved
if agg_expr.as_any().is::<expressions::Max>() {
return true;
}

if let Some(agg_expr) = agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
if agg_expr.fun().name() == "max" {
return true;
}
}

false
}

#[cfg(test)]
pub(crate) mod tests {
use super::*;
Expand Down