Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql: add a SQL IR and factor out optimizations. #80

Merged
merged 3 commits into from
Apr 27, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion arroyo-api/src/optimizations.rs
Original file line number Diff line number Diff line change
@@ -339,7 +339,7 @@ impl FusedExpressionOperatorBuilder {
}
Some(Record) => {
self.body.push(quote!(
let record:#out_type = #expression;));
let record:#out_type = #expression?;));
self.current_return_type = Some(OptionalRecord);
}
Some(OptionalRecord) => {
2 changes: 1 addition & 1 deletion arroyo-datastream/src/lib.rs
Original file line number Diff line number Diff line change
@@ -120,7 +120,7 @@ impl Debug for WindowType {
}
}

#[derive(Clone, Encode, Decode, Serialize, Deserialize, PartialEq, Eq)]
#[derive(Debug, Clone, Encode, Decode, Serialize, Deserialize, PartialEq, Eq)]
pub enum WatermarkType {
Periodic {
period: Duration,
1 change: 0 additions & 1 deletion arroyo-sql-testing/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#[cfg(test)]
mod tests {
use arroyo_sql_macro::single_test_codegen;
use chrono;

// Casts
single_test_codegen!(
111 changes: 83 additions & 28 deletions arroyo-sql/src/expressions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::fmt::Debug;

use crate::{
operators::TwoPhaseAggregation,
pipeline::SortDirection,
types::{StructDef, StructField, TypeDef},
};
@@ -28,7 +29,7 @@ pub trait ExpressionGenerator: Debug {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum Expression {
Column(ColumnExpression),
UnaryBoolean(UnaryBooleanExpression),
@@ -92,7 +93,7 @@ impl ExpressionGenerator for Expression {
}

impl Expression {
pub(crate) fn has_max_value(&self, field: &StructField) -> Option<i64> {
pub(crate) fn has_max_value(&self, field: &StructField) -> Option<u64> {
match self {
Expression::BinaryComparison(BinaryComparisonExpression { left, op, right }) => {
if let BinaryComparison::And = op {
@@ -116,13 +117,15 @@ impl Expression {
) => {
if field == column_field {
match (op, literal) {
(BinaryComparison::Lt, ScalarValue::Int64(Some(max))) => {
(BinaryComparison::Lt, ScalarValue::UInt64(Some(max))) => {
Some(*max - 1)
}
(BinaryComparison::LtEq, ScalarValue::Int64(Some(max))) => {
(BinaryComparison::LtEq, ScalarValue::UInt64(Some(max))) => {
Some(*max)
}
(BinaryComparison::Eq, ScalarValue::UInt64(Some(max))) => {
Some(*max)
}
(BinaryComparison::Eq, ScalarValue::Int64(Some(max))) => Some(*max),
_ => None,
}
} else {
@@ -135,13 +138,15 @@ impl Expression {
) => {
if field == column_field {
match (op, literal) {
(BinaryComparison::Gt, ScalarValue::Int64(Some(max))) => {
(BinaryComparison::Gt, ScalarValue::UInt64(Some(max))) => {
Some(*max + 1)
}
(BinaryComparison::GtEq, ScalarValue::Int64(Some(max))) => {
(BinaryComparison::GtEq, ScalarValue::UInt64(Some(max))) => {
Some(*max)
}
(BinaryComparison::Eq, ScalarValue::UInt64(Some(max))) => {
Some(*max)
}
(BinaryComparison::Eq, ScalarValue::Int64(Some(max))) => Some(*max),
_ => None,
}
} else {
@@ -415,7 +420,7 @@ impl Column {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct ColumnExpression {
column_field: StructField,
}
@@ -460,7 +465,7 @@ pub enum UnaryOperator {
Negative,
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct UnaryBooleanExpression {
operator: UnaryOperator,
input: Box<Expression>,
@@ -512,7 +517,7 @@ impl UnaryBooleanExpression {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct LiteralExpression {
literal: ScalarValue,
}
@@ -533,7 +538,7 @@ impl LiteralExpression {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum BinaryComparison {
Eq,
NotEq,
@@ -568,7 +573,7 @@ impl TryFrom<datafusion_expr::Operator> for BinaryComparison {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct BinaryComparisonExpression {
pub left: Box<Expression>,
pub op: BinaryComparison,
@@ -633,7 +638,7 @@ impl ExpressionGenerator for BinaryComparisonExpression {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum BinaryMathOperator {
Plus,
Minus,
@@ -670,7 +675,7 @@ impl TryFrom<datafusion_expr::Operator> for BinaryMathOperator {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct BinaryMathExpression {
left: Box<Expression>,
op: BinaryMathOperator,
@@ -718,7 +723,7 @@ impl ExpressionGenerator for BinaryMathExpression {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct StructFieldExpression {
struct_expression: Box<Expression>,
struct_field: StructField,
@@ -815,10 +820,28 @@ impl Aggregator {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct AggregationExpression {
producing_expression: Box<Expression>,
aggregator: Aggregator,
pub producing_expression: Box<Expression>,
pub aggregator: Aggregator,
}

impl TryFrom<AggregationExpression> for TwoPhaseAggregation {
type Error = anyhow::Error;

fn try_from(aggregation_expression: AggregationExpression) -> Result<Self> {
if aggregation_expression.allows_two_phase() {
Ok(TwoPhaseAggregation {
incoming_expression: *aggregation_expression.producing_expression,
aggregator: aggregation_expression.aggregator,
})
} else {
bail!(
"{:?} does not support two phase aggregation",
aggregation_expression.aggregator
);
}
}
}

impl AggregationExpression {
@@ -833,6 +856,40 @@ impl AggregationExpression {
aggregator,
}))
}

pub(crate) fn allows_two_phase(&self) -> bool {
match self.aggregator {
Aggregator::Count
| Aggregator::Sum
| Aggregator::Min
| Aggregator::Avg
| Aggregator::Max => true,
Aggregator::CountDistinct => false,
}
}

pub fn try_from_expression(expr: &Expr, input_struct: &StructDef) -> Result<Self> {
match expr {
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
fun,
args,
distinct,
filter: None,
}) => {
if args.len() != 1 {
bail!("unexpected arg length");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this an error that would be sent back to the user? or does it imply a bug in our code? if it's the latter, then it should probably be a panic instead.

}
let producing_expression =
Box::new(to_expression_generator(&args[0], input_struct)?);
let aggregator = Aggregator::from_datafusion(fun.clone(), *distinct)?;
Ok(AggregationExpression {
producing_expression,
aggregator,
})
}
_ => bail!("expected aggregate function, not {}", expr),
}
}
}

impl ExpressionGenerator for AggregationExpression {
@@ -903,7 +960,7 @@ impl ExpressionGenerator for AggregationExpression {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct CastExpression {
input: Box<Expression>,
data_type: DataType,
@@ -937,10 +994,8 @@ impl CastExpression {
{
true
// handle date to string casts.
} else if Self::is_date(input_data_type) || Self::is_string(output_data_type) {
true
} else {
false
Self::is_date(input_data_type) || Self::is_string(output_data_type)
}
}

@@ -1084,7 +1139,7 @@ impl TryFrom<BuiltinScalarFunction> for NumericFunction {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct NumericExpression {
function: NumericFunction,
input: Box<Expression>,
@@ -1113,7 +1168,7 @@ impl ExpressionGenerator for NumericExpression {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct SortExpression {
value: Expression,
direction: SortDirection,
@@ -1173,7 +1228,7 @@ impl SortExpression {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum StringFunction {
Ascii(Box<Expression>),
BitLength(Box<Expression>),
@@ -1211,7 +1266,7 @@ pub enum StringFunction {
Rtrim(Box<Expression>, Option<Box<Expression>>),
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum HashFunction {
MD5,
SHA224,
@@ -1247,7 +1302,7 @@ impl TryFrom<BuiltinScalarFunction> for HashFunction {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct HashExpression {
function: HashFunction,
input: Box<Expression>,
6 changes: 4 additions & 2 deletions arroyo-sql/src/lib.rs
Original file line number Diff line number Diff line change
@@ -11,7 +11,9 @@ use datafusion::physical_plan::functions::make_scalar_function;

mod expressions;
mod operators;
mod optimizations;
mod pipeline;
mod plan_graph;
pub mod schemas;
pub mod types;

@@ -435,10 +437,10 @@ pub fn get_test_expression(
let statement = &ast[0];
let sql_to_rel = SqlToRel::new(&schema_provider);
let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();
let mut optimizer_config = OptimizerContext::default();
let optimizer_config = OptimizerContext::default();
let optimizer = Optimizer::new();
let plan = optimizer
.optimize(&plan, &mut optimizer_config, |_plan, _rule| {})
.optimize(&plan, &optimizer_config, |_plan, _rule| {})
.unwrap();
let LogicalPlan::Projection(projection) = plan else {panic!("expect projection")};
let generating_expression = to_expression_generator(&projection.expr[0], &struct_def).unwrap();
Loading