Skip to content

Commit

Permalink
Refactor SqlToRel::sql_expr_to_logical_expr_internal to reduce stac…
Browse files Browse the repository at this point in the history
…k size (apache#12384)

* Refactor sql_expr_to_logical_expr_internal to reduce stack size

* Pass Expr by value instead of via Box

* Update datafusion/sql/src/expr/mod.rs

Co-authored-by: Andrew Lamb <[email protected]>

* Cargo fmt

* Formatting

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
Jefffrey and alamb authored Sep 9, 2024
1 parent 4569cbb commit 79b3433
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 114 deletions.
232 changes: 120 additions & 112 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ use datafusion_expr::planner::{
PlannerResult, RawBinaryExpr, RawDictionaryExpr, RawFieldAccessExpr,
};
use sqlparser::ast::{
BinaryOperator, CastKind, DictionaryField, Expr as SQLExpr, MapEntry, StructField,
Subscript, TrimWhereField, Value,
BinaryOperator, CastFormat, CastKind, DataType as SQLDataType, DictionaryField,
Expr as SQLExpr, MapEntry, StructField, Subscript, TrimWhereField, Value,
};

use datafusion_common::{
Expand Down Expand Up @@ -174,6 +174,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
schema: &DFSchema,
planner_context: &mut PlannerContext,
) -> Result<Expr> {
// NOTE: This function is called recursively, so each match arm body should be as
// small as possible to avoid stack overflows in debug builds. Follow the
// common pattern of extracting into a separate function for non-trivial
// arms. See https://github.com/apache/datafusion/pull/12384 for more context.
match sql {
SQLExpr::Value(value) => {
self.parse_value(value, planner_context.prepare_param_data_types())
Expand Down Expand Up @@ -210,91 +214,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {

// <expr>["foo"], <expr>[4] or <expr>[4:5]
SQLExpr::Subscript { expr, subscript } => {
let expr =
self.sql_expr_to_logical_expr(*expr, schema, planner_context)?;

let field_access = match *subscript {
Subscript::Index { index } => {
// index can be a name, in which case it is a named field access
match index {
SQLExpr::Value(
Value::SingleQuotedString(s)
| Value::DoubleQuotedString(s),
) => GetFieldAccess::NamedStructField {
name: ScalarValue::from(s),
},
SQLExpr::JsonAccess { .. } => {
return not_impl_err!("JsonAccess");
}
// otherwise treat like a list index
_ => GetFieldAccess::ListIndex {
key: Box::new(self.sql_expr_to_logical_expr(
index,
schema,
planner_context,
)?),
},
}
}
Subscript::Slice {
lower_bound,
upper_bound,
stride,
} => {
// Means access like [:2]
let lower_bound = if let Some(lower_bound) = lower_bound {
self.sql_expr_to_logical_expr(
lower_bound,
schema,
planner_context,
)
} else {
not_impl_err!("Slice subscript requires a lower bound")
}?;

// means access like [2:]
let upper_bound = if let Some(upper_bound) = upper_bound {
self.sql_expr_to_logical_expr(
upper_bound,
schema,
planner_context,
)
} else {
not_impl_err!("Slice subscript requires an upper bound")
}?;

// stride, default to 1
let stride = if let Some(stride) = stride {
self.sql_expr_to_logical_expr(
stride,
schema,
planner_context,
)?
} else {
lit(1i64)
};

GetFieldAccess::ListRange {
start: Box::new(lower_bound),
stop: Box::new(upper_bound),
stride: Box::new(stride),
}
}
};

let mut field_access_expr = RawFieldAccessExpr { expr, field_access };
for planner in self.context_provider.get_expr_planners() {
match planner.plan_field_access(field_access_expr, schema)? {
PlannerResult::Planned(expr) => return Ok(expr),
PlannerResult::Original(expr) => {
field_access_expr = expr;
}
}
}

not_impl_err!(
"GetFieldAccess not supported by ExprPlanner: {field_access_expr:?}"
)
self.sql_subscript_to_expr(*expr, subscript, schema, planner_context)
}

SQLExpr::CompoundIdentifier(ids) => {
Expand All @@ -320,31 +240,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
expr,
data_type,
format,
} => {
if let Some(format) = format {
return not_impl_err!("CAST with format is not supported: {format}");
}

let dt = self.convert_data_type(&data_type)?;
let expr =
self.sql_expr_to_logical_expr(*expr, schema, planner_context)?;

// numeric constants are treated as seconds (rather as nanoseconds)
// to align with postgres / duckdb semantics
let expr = match &dt {
DataType::Timestamp(TimeUnit::Nanosecond, tz)
if expr.get_type(schema)? == DataType::Int64 =>
{
Expr::Cast(Cast::new(
Box::new(expr),
DataType::Timestamp(TimeUnit::Second, tz.clone()),
))
}
_ => expr,
};

Ok(Expr::Cast(Cast::new(Box::new(expr), dt)))
}
} => self.sql_cast_to_expr(*expr, data_type, format, schema, planner_context),

SQLExpr::Cast {
kind: CastKind::TryCast | CastKind::SafeCast,
Expand Down Expand Up @@ -1016,6 +912,118 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
not_impl_err!("Overlay not supported by ExprPlanner: {overlay_args:?}")
}

/// Plans a SQL `CAST` into a logical [`Expr::Cast`].
///
/// A cast carrying a `format` is rejected with a not-implemented error before
/// anything else is planned. Otherwise the SQL data type is converted with
/// `convert_data_type`, the operand is planned, and the result is wrapped in a
/// cast to the converted type.
///
/// When the target type is a nanosecond timestamp and the planned operand has
/// type `Int64`, the operand is first cast to a second-resolution timestamp
/// (same timezone), so numeric constants are read as seconds rather than
/// nanoseconds, in line with postgres / duckdb semantics.
fn sql_cast_to_expr(
    &self,
    expr: SQLExpr,
    data_type: SQLDataType,
    format: Option<CastFormat>,
    schema: &DFSchema,
    planner_context: &mut PlannerContext,
) -> Result<Expr> {
    if let Some(format) = format {
        return not_impl_err!("CAST with format is not supported: {format}");
    }

    let target_type = self.convert_data_type(&data_type)?;
    let planned = self.sql_expr_to_logical_expr(expr, schema, planner_context)?;

    // Decide whether an intermediate "seconds" cast is needed; the operand's
    // type is only inspected when the target is a nanosecond timestamp.
    let intermediate_type = match &target_type {
        DataType::Timestamp(TimeUnit::Nanosecond, tz)
            if planned.get_type(schema)? == DataType::Int64 =>
        {
            Some(DataType::Timestamp(TimeUnit::Second, tz.clone()))
        }
        _ => None,
    };

    let planned = match intermediate_type {
        Some(seconds_type) => Expr::Cast(Cast::new(Box::new(planned), seconds_type)),
        None => planned,
    };

    Ok(Expr::Cast(Cast::new(Box::new(planned), target_type)))
}

/// Plans a SQL subscript expression such as `<expr>["foo"]`, `<expr>[4]` or
/// `<expr>[4:5]`.
///
/// The subscripted expression is planned first, then the subscript is turned
/// into a [`GetFieldAccess`]:
///
/// * an index that is a single- or double-quoted string literal becomes a
///   named struct field access;
/// * an index that is a `JsonAccess` expression yields a not-implemented error;
/// * any other index is planned and used as a list index;
/// * a slice needs both a lower and an upper bound (a missing one yields a
///   not-implemented error) and its stride defaults to the literal `1i64`.
///
/// The resulting [`RawFieldAccessExpr`] is offered to each registered expr
/// planner in turn. The first planner that returns a planned expression wins;
/// if none does, a not-implemented error is returned.
fn sql_subscript_to_expr(
    &self,
    expr: SQLExpr,
    subscript: Box<Subscript>,
    schema: &DFSchema,
    planner_context: &mut PlannerContext,
) -> Result<Expr> {
    let base = self.sql_expr_to_logical_expr(expr, schema, planner_context)?;

    let field_access = match *subscript {
        Subscript::Index { index } => match index {
            SQLExpr::JsonAccess { .. } => return not_impl_err!("JsonAccess"),
            // A quoted string index is a named field access.
            SQLExpr::Value(
                Value::SingleQuotedString(field) | Value::DoubleQuotedString(field),
            ) => GetFieldAccess::NamedStructField {
                name: ScalarValue::from(field),
            },
            // Anything else is treated like a list index.
            other => {
                let key = self.sql_expr_to_logical_expr(other, schema, planner_context)?;
                GetFieldAccess::ListIndex { key: Box::new(key) }
            }
        },
        Subscript::Slice {
            lower_bound,
            upper_bound,
            stride,
        } => {
            // An omitted lower bound (e.g. `[:2]`) is not supported.
            let start = match lower_bound {
                Some(bound) => {
                    self.sql_expr_to_logical_expr(bound, schema, planner_context)?
                }
                None => return not_impl_err!("Slice subscript requires a lower bound"),
            };

            // An omitted upper bound (e.g. `[2:]`) is not supported.
            let stop = match upper_bound {
                Some(bound) => {
                    self.sql_expr_to_logical_expr(bound, schema, planner_context)?
                }
                None => return not_impl_err!("Slice subscript requires an upper bound"),
            };

            // The stride is optional and defaults to 1.
            let step = match stride {
                Some(step) => {
                    self.sql_expr_to_logical_expr(step, schema, planner_context)?
                }
                None => lit(1i64),
            };

            GetFieldAccess::ListRange {
                start: Box::new(start),
                stop: Box::new(stop),
                stride: Box::new(step),
            }
        }
    };

    // Hand the raw access to each planner; a planner that declines returns the
    // expression back so the next one can try.
    let mut field_access_expr = RawFieldAccessExpr {
        expr: base,
        field_access,
    };
    for planner in self.context_provider.get_expr_planners() {
        field_access_expr = match planner.plan_field_access(field_access_expr, schema)? {
            PlannerResult::Planned(planned) => return Ok(planned),
            PlannerResult::Original(unplanned) => unplanned,
        };
    }

    not_impl_err!(
        "GetFieldAccess not supported by ExprPlanner: {field_access_expr:?}"
    )
}
}

#[cfg(test)]
Expand Down
2 changes: 0 additions & 2 deletions datafusion/sqllogictest/bin/sqllogictests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,9 @@ use datafusion_common_runtime::SpawnedTask;

const TEST_DIRECTORY: &str = "test_files/";
const PG_COMPAT_FILE_PREFIX: &str = "pg_compat_";
const STACK_SIZE: usize = 2 * 1024 * 1024 + 512 * 1024; // 2.5 MBs, the default 2 MBs is currently too small

pub fn main() -> Result<()> {
tokio::runtime::Builder::new_multi_thread()
.thread_stack_size(STACK_SIZE)
.enable_all()
.build()
.unwrap()
Expand Down

0 comments on commit 79b3433

Please sign in to comment.