diff --git a/ballista/core/proto/ballista.proto b/ballista/core/proto/ballista.proto index c169791d0..a5bbb2c43 100644 --- a/ballista/core/proto/ballista.proto +++ b/ballista/core/proto/ballista.proto @@ -532,6 +532,7 @@ message ExecuteQueryParams { oneof query { bytes logical_plan = 1; string sql = 2; + bytes substrait_plan = 5; } oneof optional_session_id { string session_id = 3; diff --git a/ballista/core/src/serde/generated/ballista.rs b/ballista/core/src/serde/generated/ballista.rs index 28236ad04..3ed905a54 100644 --- a/ballista/core/src/serde/generated/ballista.rs +++ b/ballista/core/src/serde/generated/ballista.rs @@ -898,7 +898,7 @@ pub struct UpdateTaskStatusResult { pub struct ExecuteQueryParams { #[prost(message, repeated, tag = "4")] pub settings: ::prost::alloc::vec::Vec, - #[prost(oneof = "execute_query_params::Query", tags = "1, 2")] + #[prost(oneof = "execute_query_params::Query", tags = "1, 2, 5")] pub query: ::core::option::Option, #[prost(oneof = "execute_query_params::OptionalSessionId", tags = "3")] pub optional_session_id: ::core::option::Option< @@ -914,6 +914,8 @@ pub mod execute_query_params { LogicalPlan(::prost::alloc::vec::Vec), #[prost(string, tag = "2")] Sql(::prost::alloc::string::String), + #[prost(bytes, tag = "5")] + SubstraitPlan(::prost::alloc::vec::Vec), } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] diff --git a/ballista/scheduler/Cargo.toml b/ballista/scheduler/Cargo.toml index 59a6ca78f..756f5755e 100644 --- a/ballista/scheduler/Cargo.toml +++ b/ballista/scheduler/Cargo.toml @@ -53,6 +53,7 @@ configure_me = "0.4.0" dashmap = "5.4.0" datafusion = "18.0.0" datafusion-proto = "18.0.0" +datafusion-substrait = "18.0.0" etcd-client = { version = "0.10", optional = true } flatbuffers = { version = "22.9.29" } futures = "0.3" diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs b/ballista/scheduler/src/scheduler_server/grpc.rs index bff078d01..722c4f8f9 100644 --- 
a/ballista/scheduler/src/scheduler_server/grpc.rs +++ b/ballista/scheduler/src/scheduler_server/grpc.rs @@ -35,6 +35,8 @@ use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::file_format::FileFormat; use datafusion_proto::logical_plan::AsLogicalPlan; use datafusion_proto::physical_plan::AsExecutionPlan; +use datafusion_substrait::serializer::deserialize_bytes; + use futures::TryStreamExt; use log::{debug, error, info, trace, warn}; use object_store::{local::LocalFileSystem, path::Path, ObjectStore}; @@ -44,6 +46,7 @@ use std::sync::Arc; use crate::scheduler_server::event::QueryStageSchedulerEvent; use datafusion::prelude::SessionContext; +use datafusion_substrait::logical_plan::consumer::from_substrait_plan; use std::time::{SystemTime, UNIX_EPOCH}; use tonic::{Request, Response, Status}; @@ -407,6 +410,20 @@ impl SchedulerGrpc }; let plan = match query { + Query::SubstraitPlan(bytes) => { + let plan = deserialize_bytes(bytes).await.map_err(|e| { + let msg = format!("Could not parse substrait plan: {e}"); + error!("{}", msg); + Status::internal(msg) + })?; + + let mut ctx = session_ctx.as_ref().clone(); + from_substrait_plan(&mut ctx, &plan).await.map_err(|e| { + let msg = format!("Could not convert substrait plan to logical plan: {e}"); + error!("{}", msg); + Status::internal(msg) + })? + } Query::LogicalPlan(message) => T::try_decode(message.as_slice()) .and_then(|m| { m.try_into_logical_plan(