Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql: add a SQL IR and factor out optimizations. #80

Merged
merged 3 commits into from
Apr 27, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion arroyo-api/src/optimizations.rs
Original file line number Diff line number Diff line change
@@ -339,7 +339,7 @@ impl FusedExpressionOperatorBuilder {
}
Some(Record) => {
self.body.push(quote!(
let record:#out_type = #expression;));
let record:#out_type = #expression?;));
self.current_return_type = Some(OptionalRecord);
}
Some(OptionalRecord) => {
2 changes: 1 addition & 1 deletion arroyo-datastream/src/lib.rs
Original file line number Diff line number Diff line change
@@ -120,7 +120,7 @@ impl Debug for WindowType {
}
}

#[derive(Clone, Encode, Decode, Serialize, Deserialize, PartialEq, Eq)]
#[derive(Debug, Clone, Encode, Decode, Serialize, Deserialize, PartialEq, Eq)]
pub enum WatermarkType {
Periodic {
period: Duration,
1 change: 0 additions & 1 deletion arroyo-sql-testing/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#[cfg(test)]
mod tests {
use arroyo_sql_macro::single_test_codegen;
use chrono;

// Casts
single_test_codegen!(
111 changes: 83 additions & 28 deletions arroyo-sql/src/expressions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::fmt::Debug;

use crate::{
operators::TwoPhaseAggregation,
pipeline::SortDirection,
types::{StructDef, StructField, TypeDef},
};
@@ -28,7 +29,7 @@ pub trait ExpressionGenerator: Debug {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum Expression {
Column(ColumnExpression),
UnaryBoolean(UnaryBooleanExpression),
@@ -92,7 +93,7 @@ impl ExpressionGenerator for Expression {
}

impl Expression {
pub(crate) fn has_max_value(&self, field: &StructField) -> Option<i64> {
pub(crate) fn has_max_value(&self, field: &StructField) -> Option<u64> {
match self {
Expression::BinaryComparison(BinaryComparisonExpression { left, op, right }) => {
if let BinaryComparison::And = op {
@@ -116,13 +117,15 @@ impl Expression {
) => {
if field == column_field {
match (op, literal) {
(BinaryComparison::Lt, ScalarValue::Int64(Some(max))) => {
(BinaryComparison::Lt, ScalarValue::UInt64(Some(max))) => {
Some(*max - 1)
}
(BinaryComparison::LtEq, ScalarValue::Int64(Some(max))) => {
(BinaryComparison::LtEq, ScalarValue::UInt64(Some(max))) => {
Some(*max)
}
(BinaryComparison::Eq, ScalarValue::UInt64(Some(max))) => {
Some(*max)
}
(BinaryComparison::Eq, ScalarValue::Int64(Some(max))) => Some(*max),
_ => None,
}
} else {
@@ -135,13 +138,15 @@ impl Expression {
) => {
if field == column_field {
match (op, literal) {
(BinaryComparison::Gt, ScalarValue::Int64(Some(max))) => {
(BinaryComparison::Gt, ScalarValue::UInt64(Some(max))) => {
Some(*max + 1)
}
(BinaryComparison::GtEq, ScalarValue::Int64(Some(max))) => {
(BinaryComparison::GtEq, ScalarValue::UInt64(Some(max))) => {
Some(*max)
}
(BinaryComparison::Eq, ScalarValue::UInt64(Some(max))) => {
Some(*max)
}
(BinaryComparison::Eq, ScalarValue::Int64(Some(max))) => Some(*max),
_ => None,
}
} else {
@@ -415,7 +420,7 @@ impl Column {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct ColumnExpression {
column_field: StructField,
}
@@ -460,7 +465,7 @@ pub enum UnaryOperator {
Negative,
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct UnaryBooleanExpression {
operator: UnaryOperator,
input: Box<Expression>,
@@ -512,7 +517,7 @@ impl UnaryBooleanExpression {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct LiteralExpression {
literal: ScalarValue,
}
@@ -533,7 +538,7 @@ impl LiteralExpression {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum BinaryComparison {
Eq,
NotEq,
@@ -568,7 +573,7 @@ impl TryFrom<datafusion_expr::Operator> for BinaryComparison {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct BinaryComparisonExpression {
pub left: Box<Expression>,
pub op: BinaryComparison,
@@ -633,7 +638,7 @@ impl ExpressionGenerator for BinaryComparisonExpression {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum BinaryMathOperator {
Plus,
Minus,
@@ -670,7 +675,7 @@ impl TryFrom<datafusion_expr::Operator> for BinaryMathOperator {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct BinaryMathExpression {
left: Box<Expression>,
op: BinaryMathOperator,
@@ -718,7 +723,7 @@ impl ExpressionGenerator for BinaryMathExpression {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct StructFieldExpression {
struct_expression: Box<Expression>,
struct_field: StructField,
@@ -815,10 +820,28 @@ impl Aggregator {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct AggregationExpression {
producing_expression: Box<Expression>,
aggregator: Aggregator,
pub producing_expression: Box<Expression>,
pub aggregator: Aggregator,
}

impl TryFrom<AggregationExpression> for TwoPhaseAggregation {
type Error = anyhow::Error;

fn try_from(aggregation_expression: AggregationExpression) -> Result<Self> {
if aggregation_expression.allows_two_phase() {
Ok(TwoPhaseAggregation {
incoming_expression: *aggregation_expression.producing_expression,
aggregator: aggregation_expression.aggregator,
})
} else {
bail!(
"{:?} does not support two phase aggregation",
aggregation_expression.aggregator
);
}
}
}

impl AggregationExpression {
@@ -833,6 +856,40 @@ impl AggregationExpression {
aggregator,
}))
}

pub(crate) fn allows_two_phase(&self) -> bool {
match self.aggregator {
Aggregator::Count
| Aggregator::Sum
| Aggregator::Min
| Aggregator::Avg
| Aggregator::Max => true,
Aggregator::CountDistinct => false,
}
}

pub fn try_from_expression(expr: &Expr, input_struct: &StructDef) -> Result<Self> {
match expr {
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
fun,
args,
distinct,
filter: None,
}) => {
if args.len() != 1 {
bail!("unexpected arg length");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this an error that would be sent back to the user? or does it imply a bug in our code? if it's the latter, then it should probably be a panic instead.

}
let producing_expression =
Box::new(to_expression_generator(&args[0], input_struct)?);
let aggregator = Aggregator::from_datafusion(fun.clone(), *distinct)?;
Ok(AggregationExpression {
producing_expression,
aggregator,
})
}
_ => bail!("expected aggregate function, not {}", expr),
}
}
}

impl ExpressionGenerator for AggregationExpression {
@@ -903,7 +960,7 @@ impl ExpressionGenerator for AggregationExpression {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct CastExpression {
input: Box<Expression>,
data_type: DataType,
@@ -937,10 +994,8 @@ impl CastExpression {
{
true
// handle date to string casts.
} else if Self::is_date(input_data_type) || Self::is_string(output_data_type) {
true
} else {
false
Self::is_date(input_data_type) || Self::is_string(output_data_type)
}
}

@@ -1084,7 +1139,7 @@ impl TryFrom<BuiltinScalarFunction> for NumericFunction {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct NumericExpression {
function: NumericFunction,
input: Box<Expression>,
@@ -1113,7 +1168,7 @@ impl ExpressionGenerator for NumericExpression {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct SortExpression {
value: Expression,
direction: SortDirection,
@@ -1173,7 +1228,7 @@ impl SortExpression {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum StringFunction {
Ascii(Box<Expression>),
BitLength(Box<Expression>),
@@ -1211,7 +1266,7 @@ pub enum StringFunction {
Rtrim(Box<Expression>, Option<Box<Expression>>),
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum HashFunction {
MD5,
SHA224,
@@ -1247,7 +1302,7 @@ impl TryFrom<BuiltinScalarFunction> for HashFunction {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct HashExpression {
function: HashFunction,
input: Box<Expression>,
6 changes: 4 additions & 2 deletions arroyo-sql/src/lib.rs
Original file line number Diff line number Diff line change
@@ -11,7 +11,9 @@ use datafusion::physical_plan::functions::make_scalar_function;

mod expressions;
mod operators;
mod optimizations;
mod pipeline;
mod plan_graph;
pub mod schemas;
pub mod types;

@@ -435,10 +437,10 @@ pub fn get_test_expression(
let statement = &ast[0];
let sql_to_rel = SqlToRel::new(&schema_provider);
let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();
let mut optimizer_config = OptimizerContext::default();
let optimizer_config = OptimizerContext::default();
let optimizer = Optimizer::new();
let plan = optimizer
.optimize(&plan, &mut optimizer_config, |_plan, _rule| {})
.optimize(&plan, &optimizer_config, |_plan, _rule| {})
.unwrap();
let LogicalPlan::Projection(projection) = plan else {panic!("expect projection")};
let generating_expression = to_expression_generator(&projection.expr[0], &struct_def).unwrap();
54 changes: 35 additions & 19 deletions arroyo-sql/src/operators.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::time::Duration;

use crate::{
expressions::{to_expression_generator, Aggregator, Column, Expression, ExpressionGenerator},
expressions::{
to_expression_generator, AggregationExpression, Aggregator, Column, Expression,
ExpressionGenerator,
},
schemas::window_type_def,
types::{StructDef, StructField, TypeDef},
};
@@ -18,7 +21,7 @@ use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_quote, parse_str, Ident, LitInt};

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Projection {
pub field_names: Vec<Column>,
pub field_computations: Vec<Expression>,
@@ -147,11 +150,12 @@ impl ExpressionGenerator for Projection {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct AggregateProjection {
pub field_names: Vec<Column>,
pub field_computations: Vec<Expression>,
pub field_computations: Vec<AggregationExpression>,
}

impl AggregateProjection {
pub fn output_struct(&self) -> StructDef {
let fields = self
@@ -208,7 +212,7 @@ impl ExpressionGenerator for AggregateProjection {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum GroupByKind {
Basic,
WindowOutput {
@@ -219,12 +223,7 @@ pub enum GroupByKind {
}

impl GroupByKind {
pub fn output_struct(
&self,
key_projection: &Projection,
aggregate_struct: StructDef,
) -> StructDef {
let key_struct = key_projection.output_struct();
pub fn output_struct(&self, key_struct: &StructDef, aggregate_struct: &StructDef) -> StructDef {
let key_fields = key_struct.fields.len();
let aggregate_fields = aggregate_struct.fields.len();
match self {
@@ -271,11 +270,10 @@ impl GroupByKind {

pub fn to_syn_expression(
&self,
key_projection: &Projection,
aggregate_struct: StructDef,
key_struct: &StructDef,
aggregate_struct: &StructDef,
) -> syn::Expr {
let mut assignments: Vec<_> = vec![];
let key_struct = key_projection.output_struct();

key_struct.fields.iter().for_each(|field| {
let field_name: Ident = format_ident!("{}", field.field_name());
@@ -285,7 +283,7 @@ impl GroupByKind {
let field_name: Ident = format_ident!("{}", field.field_name());
assignments.push(quote!(#field_name : arg.aggregate.#field_name.clone()));
});
let return_struct = self.output_struct(key_projection, aggregate_struct);
let return_struct = self.output_struct(key_struct, aggregate_struct);
if let GroupByKind::WindowOutput {
index,
column: _,
@@ -312,12 +310,27 @@ impl GroupByKind {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct TwoPhaseAggregateProjection {
pub field_names: Vec<Column>,
pub field_computations: Vec<TwoPhaseAggregation>,
}

impl TryFrom<AggregateProjection> for TwoPhaseAggregateProjection {
type Error = anyhow::Error;

fn try_from(aggregate_projection: AggregateProjection) -> Result<Self> {
Ok(Self {
field_names: aggregate_projection.field_names,
field_computations: aggregate_projection
.field_computations
.into_iter()
.map(|computation| computation.try_into().unwrap())
.collect(),
})
}
}

impl TwoPhaseAggregateProjection {
pub fn combine_bin_syn_expr(&self) -> syn::Expr {
let some_assignments: Vec<syn::Expr> = self
@@ -544,14 +557,17 @@ impl TwoPhaseAggregateProjection {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct TwoPhaseAggregation {
incoming_expression: Expression,
aggregator: Aggregator,
pub incoming_expression: Expression,
pub aggregator: Aggregator,
}

impl TwoPhaseAggregation {
pub fn from_expression(expr: &Expr, input_struct: &StructDef) -> Result<TwoPhaseAggregation> {
if !input_struct.fields.is_empty() {
bail!("expected single field input");
}
match expr {
Expr::AggregateFunction(AggregateFunction {
fun,
532 changes: 532 additions & 0 deletions arroyo-sql/src/optimizations.rs

Large diffs are not rendered by default.

1,188 changes: 105 additions & 1,083 deletions arroyo-sql/src/pipeline.rs

Large diffs are not rendered by default.

1,275 changes: 1,275 additions & 0 deletions arroyo-sql/src/plan_graph.rs

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions arroyo-sql/src/types.rs
Original file line number Diff line number Diff line change
@@ -23,6 +23,11 @@ pub struct StructDef {
pub name: Option<String>,
pub fields: Vec<StructField>,
}
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct StructPair {
pub left: StructDef,
pub right: StructDef,
}

impl StructDef {
pub fn struct_name(&self) -> String {
21 changes: 0 additions & 21 deletions arroyo-worker/src/operators/join_with_expiration.rs
Original file line number Diff line number Diff line change
@@ -14,27 +14,6 @@ pub struct JoinWithExpiration<K: Key, T1: Data, T2: Data> {
_t: PhantomData<(K, T1, T2)>,
}

enum Side {
Left,
Right,
}

impl Side {
fn get_primary_side_char(&self) -> char {
match self {
Side::Left => 'l',
Side::Right => 'r',
}
}

fn get_secondary_side_char(&self) -> char {
match self {
Side::Left => 'r',
Side::Right => 'l',
}
}
}

#[co_process_fn(in_k1=K, in_t1=T1, in_k2=K, in_t2=T2, out_k=K, out_t=(T1,T2))]
impl<K: Key, T1: Data, T2: Data> JoinWithExpiration<K, T1, T2> {
fn name(&self) -> String {