Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: volatile expressions should not be target of common subexpr elimination #8520

Merged
merged 5 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 74 additions & 1 deletion datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,24 @@ impl ScalarFunctionDefinition {
ScalarFunctionDefinition::Name(func_name) => func_name.as_ref(),
}
}

/// Whether this function is volatile, i.e. whether it can return different results
/// when evaluated multiple times with the same input.
pub fn is_volatile(&self) -> Result<bool> {
match self {
ScalarFunctionDefinition::BuiltIn(fun) => {
Ok(fun.volatility() == crate::Volatility::Volatile)
}
ScalarFunctionDefinition::UDF(udf) => {
Ok(udf.signature().volatility == crate::Volatility::Volatile)
}
ScalarFunctionDefinition::Name(func) => {
internal_err!(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

"Cannot determine volatility of unresolved function: {func}"
)
}
}
}
}

impl ScalarFunction {
Expand Down Expand Up @@ -1692,14 +1710,28 @@ fn create_names(exprs: &[Expr]) -> Result<String> {
.join(", "))
}

/// Whether the given expression is volatile, i.e. whether it can return different results
viirya marked this conversation as resolved.
Show resolved Hide resolved
/// when evaluated multiple times with the same input.
pub fn is_volatile(expr: &Expr) -> Result<bool> {
    // Only scalar function calls carry a volatility; defer to the function
    // definition, which returns an error if it cannot answer.
    if let Expr::ScalarFunction(func) = expr {
        return func.func_def.is_volatile();
    }

    // Every other expression kind is reported as non-volatile.
    // NOTE(review): only the top-level node is inspected here; child
    // expressions are not traversed by this function — confirm that callers
    // visit sub-expressions themselves.
    Ok(false)
}

#[cfg(test)]
mod test {
use crate::expr::Cast;
use crate::expr_fn::col;
use crate::{case, lit, Expr};
use crate::{
case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ReturnTypeFunction,
ScalarFunctionDefinition, ScalarFunctionImplementation, ScalarUDF, Signature,
Volatility,
};
use arrow::datatypes::DataType;
use datafusion_common::Column;
use datafusion_common::{Result, ScalarValue};
use std::sync::Arc;

#[test]
fn format_case_when() -> Result<()> {
Expand Down Expand Up @@ -1800,4 +1832,45 @@ mod test {
"UInt32(1) OR UInt32(2)"
);
}

#[test]
fn test_is_volatile_scalar_func_definition() {
    // Built-in functions: `Random` is expected to be volatile, `Abs` is not.
    let random = ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Random);
    assert!(random.is_volatile().unwrap());

    let abs = ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Abs);
    assert!(!abs.is_volatile().unwrap());

    // User-defined functions: the answer follows the volatility given in the
    // signature. The return type and implementation are placeholders shared
    // by both UDFs below.
    let return_type: ReturnTypeFunction =
        Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
    let fun: ScalarFunctionImplementation =
        Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a"))));

    // Builds a UDF-backed definition that differs only in its volatility.
    let udf_with = |volatility: Volatility| {
        ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new(
            "TestScalarUDF",
            &Signature::uniform(1, vec![DataType::Float32], volatility),
            &return_type,
            &fun,
        )))
    };

    assert!(!udf_with(Volatility::Stable).is_volatile().unwrap());
    assert!(udf_with(Volatility::Volatile).is_volatile().unwrap());

    // A function known only by name has no volatility to report, so the
    // query must fail rather than guess.
    let unresolved = ScalarFunctionDefinition::Name(Arc::from("UnresolvedFunc"));
    unresolved
        .is_volatile()
        .expect_err("Shouldn't determine volatility of unresolved function");
}
}
18 changes: 11 additions & 7 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use datafusion_common::tree_node::{
use datafusion_common::{
internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result,
};
use datafusion_expr::expr::Alias;
use datafusion_expr::expr::{is_volatile, Alias};
use datafusion_expr::logical_plan::{
Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
};
Expand Down Expand Up @@ -113,6 +113,8 @@ impl CommonSubexprEliminate {
let Projection { expr, input, .. } = projection;
let input_schema = Arc::clone(input.schema());
let mut expr_set = ExprSet::new();

// Visit the expr list and build a map from expr identifier to occurrence count (`expr_set`).
let arrays = to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?;

let (mut new_expr, new_input) =
Expand Down Expand Up @@ -516,7 +518,7 @@ enum ExprMask {
}

impl ExprMask {
fn ignores(&self, expr: &Expr) -> bool {
fn ignores(&self, expr: &Expr) -> Result<bool> {
let is_normal_minus_aggregates = matches!(
expr,
Expr::Literal(..)
Expand All @@ -527,12 +529,14 @@ impl ExprMask {
| Expr::Wildcard { .. }
);

let is_volatile = is_volatile(expr)?;

let is_aggr = matches!(expr, Expr::AggregateFunction(..));

match self {
Self::Normal => is_normal_minus_aggregates || is_aggr,
Self::NormalAndAggregates => is_normal_minus_aggregates,
}
Ok(match self {
Self::Normal => is_volatile || is_normal_minus_aggregates || is_aggr,
Self::NormalAndAggregates => is_volatile || is_normal_minus_aggregates,
})
}
}

Expand Down Expand Up @@ -624,7 +628,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {

let (idx, sub_expr_desc) = self.pop_enter_mark();
// Skip expressions that should not be recognized as common subexpressions.
if self.expr_mask.ignores(expr) {
if self.expr_mask.ignores(expr)? {
self.id_array[idx].0 = self.series_number;
let desc = Self::desc_expr(expr);
self.visit_stack.push(VisitRecord::ExprItem(desc));
Expand Down
6 changes: 6 additions & 0 deletions datafusion/sqllogictest/test_files/functions.slt
Original file line number Diff line number Diff line change
Expand Up @@ -995,3 +995,9 @@ query ?
SELECT find_in_set(NULL, NULL)
----
NULL

# Verify that multiple calls to volatile functions like `random()` are not combined / optimized away
query B
viirya marked this conversation as resolved.
Show resolved Hide resolved
SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random() r1, random() r2) WHERE r1 > 0 AND r2 > 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

This test fails on main:

❯ SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random() r1, random() r2) WHERE r1 > 0 AND r2 > 0)
;
+------+
| r    |
+------+
| true |
+------+
1 row in set. Query took 0.037 seconds.

----
false
viirya marked this conversation as resolved.
Show resolved Hide resolved