Skip to content

Commit

Permalink
Plan and optimize generating expressions (#778)
Browse files Browse the repository at this point in the history
  • Loading branch information
mwylde authored Nov 1, 2024
1 parent 4f56baf commit 347c2f7
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 16 deletions.
4 changes: 2 additions & 2 deletions crates/arroyo-api/src/pipelines.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,8 @@ pub(crate) async fn create_pipeline_int<'a>(
"job_id": job_id,
"parallelism": parallelism,
"has_udfs": udfs.first().map(|e| !e.definition.trim().is_empty()).unwrap_or(false),
"rust_udfs": udfs.iter().find(|e| e.language == UdfLanguage::Rust),
"python_udfs": udfs.iter().find(|e| e.language == UdfLanguage::Python),
"rust_udfs": udfs.iter().any(|e| e.language == UdfLanguage::Rust),
"python_udfs": udfs.iter().any(|e| e.language == UdfLanguage::Python),
// TODO: program features
"features": compiled.program.features(),
}),
Expand Down
11 changes: 7 additions & 4 deletions crates/arroyo-planner/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use datafusion::physical_plan::functions::make_scalar_function;
use datafusion::prelude::{create_udf, SessionConfig};

use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
use datafusion::sql::sqlparser::parser::Parser;
use datafusion::sql::sqlparser::parser::{Parser, ParserError};
use datafusion::sql::{planner::ContextProvider, sqlparser, TableReference};

use datafusion::logical_expr::expr::ScalarFunction;
Expand Down Expand Up @@ -649,14 +649,17 @@ fn try_handle_set_variable(
Ok(false)
}

pub(crate) fn parse_sql(sql: &str) -> Result<Vec<Statement>, ParserError> {
let dialect = PostgreSqlDialect {};
Parser::parse_sql(&dialect, sql)
}

pub async fn parse_and_get_arrow_program(
query: String,
mut schema_provider: ArroyoSchemaProvider,
// TODO: use config
_config: SqlConfig,
) -> Result<CompiledSql> {
let dialect = PostgreSqlDialect {};

let mut config = SessionConfig::new();
config
.options_mut()
Expand All @@ -669,7 +672,7 @@ pub async fn parse_and_get_arrow_program(
.with_physical_optimizer_rules(vec![]);

let mut inserts = vec![];
for statement in Parser::parse_sql(&dialect, &query)? {
for statement in parse_sql(&query)? {
if try_handle_set_variable(&statement, &mut schema_provider)? {
continue;
}
Expand Down
53 changes: 43 additions & 10 deletions crates/arroyo-planner/src/tables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::extension::remote_table::RemoteTableExtension;
use crate::types::convert_data_type;
use crate::{
external::{ProcessingMode, SqlSource},
fields_with_qualifiers, ArroyoSchemaProvider, DFField,
fields_with_qualifiers, parse_sql, ArroyoSchemaProvider, DFField,
};
use crate::{rewrite_plan, DEFAULT_IDLE_TIME};
use arroyo_datastream::default_sink;
Expand Down Expand Up @@ -51,7 +51,6 @@ use datafusion::optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin;
use datafusion::optimizer::simplify_expressions::SimplifyExpressions;
use datafusion::optimizer::unwrap_cast_in_comparison::UnwrapCastInComparison;
use datafusion::optimizer::OptimizerRule;
use datafusion::sql::planner::PlannerContext;
use datafusion::sql::sqlparser;
use datafusion::sql::sqlparser::ast::Query;
use datafusion::{
Expand Down Expand Up @@ -504,6 +503,40 @@ fn value_to_inner_string(value: &Value) -> Result<String> {
}
}

fn plan_generating_expr(
expr: &sqlparser::ast::Expr,
name: &str,
schema: &DFSchema,
schema_provider: &ArroyoSchemaProvider,
session_state: &SessionState,
) -> Result<Expr, DataFusionError> {
let sql = format!("SELECT {} from {}", expr, name);
let statement = parse_sql(&sql)
.expect("generating expression should be valid")
.into_iter()
.next()
.expect("generating expression should produce one statement");

let mut schema_provider = schema_provider.clone();
schema_provider.insert_table(Table::MemoryTable {
name: name.to_string(),
fields: schema.fields().to_vec(),
logical_plan: None,
});

let plan = produce_optimized_plan(&statement, &schema_provider, session_state)?;

match plan {
LogicalPlan::Projection(p) => Ok(p.expr.into_iter().next().unwrap()),
p => {
unreachable!(
"top-level plan from generating expression should be a projection, but is {:?}",
p
);
}
}
}

#[derive(Default)]
struct MetadataFinder {
key: Option<String>,
Expand Down Expand Up @@ -546,8 +579,10 @@ impl<'a> TreeNodeVisitor<'a> for MetadataFinder {

impl Table {
fn schema_from_columns(
table_name: &str,
columns: &[ColumnDef],
schema_provider: &ArroyoSchemaProvider,
session_state: &SessionState,
) -> Result<Vec<FieldSpec>> {
let struct_field_pairs = columns
.iter()
Expand Down Expand Up @@ -600,18 +635,16 @@ impl Table {
HashMap::new(),
)?;

let sql_to_rel = SqlToRel::new(schema_provider);
struct_field_pairs
.into_iter()
.map(|(struct_field, generating_expression)| {
if let Some(generating_expression) = generating_expression {
// TODO: Implement automatic type coercion here, as we have elsewhere.
// It is done by calling the Analyzer which inserts CAST operators where necessary.

let df_expr = sql_to_rel.sql_to_expr(
generating_expression,
let df_expr = plan_generating_expr(
&generating_expression,
table_name,
&physical_schema,
&mut PlannerContext::default(),
schema_provider,
session_state,
)?;

let mut metadata_finder = MetadataFinder::default();
Expand Down Expand Up @@ -658,7 +691,7 @@ impl Table {
}

let connector = with_map.remove("connector");
let fields = Self::schema_from_columns(columns, schema_provider)?;
let fields = Self::schema_from_columns(&name, columns, schema_provider, session_state)?;

let primary_keys = columns
.iter()
Expand Down
10 changes: 10 additions & 0 deletions crates/arroyo-planner/src/test/queries/subscript_in_virtual.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
create table input (
length JSON,
diff INT GENERATED ALWAYS AS (extract_json(length, '$.old')[1]) STORED
) with (
connector = 'sse',
endpoint = 'https://localhost:9091',
format = 'json'
);

select * from input;
11 changes: 11 additions & 0 deletions crates/arroyo-planner/src/test/queries/virtual_bad_schema.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
--fail=Schema error: No field named notfield. Valid fields are input.length.
create table input (
length JSON,
diff INT GENERATED ALWAYS AS (notfield) STORED
) with (
connector = 'sse',
endpoint = 'https://localhost:9091',
format = 'json'
);

select * from input;

0 comments on commit 347c2f7

Please sign in to comment.