From 76980a53b80112fd6ed2f8efecb691036ce38e0b Mon Sep 17 00:00:00 2001 From: hozan23 <119854621+hozan23@users.noreply.github.com> Date: Tue, 4 Feb 2025 08:49:40 +0100 Subject: [PATCH] Move flight-sql code to datafusion-contrib/datafusion-flight-sql-server repo (#110) --- .release-plz.toml | 8 +- Cargo.toml | 16 +- datafusion-flight-sql-server/CHANGELOG.md | 38 - datafusion-flight-sql-server/Cargo.toml | 29 - datafusion-flight-sql-server/README.md | 41 - .../examples/flight-sql.rs | 84 -- .../examples/test.csv | 4 - datafusion-flight-sql-server/src/lib.rs | 5 - datafusion-flight-sql-server/src/service.rs | 1149 ----------------- datafusion-flight-sql-server/src/session.rs | 31 - datafusion-flight-sql-server/src/state.rs | 120 -- .../Cargo.toml | 15 - .../src/lib.rs | 105 -- 13 files changed, 3 insertions(+), 1642 deletions(-) delete mode 100644 datafusion-flight-sql-server/CHANGELOG.md delete mode 100644 datafusion-flight-sql-server/Cargo.toml delete mode 100644 datafusion-flight-sql-server/README.md delete mode 100644 datafusion-flight-sql-server/examples/flight-sql.rs delete mode 100644 datafusion-flight-sql-server/examples/test.csv delete mode 100644 datafusion-flight-sql-server/src/lib.rs delete mode 100644 datafusion-flight-sql-server/src/service.rs delete mode 100644 datafusion-flight-sql-server/src/session.rs delete mode 100644 datafusion-flight-sql-server/src/state.rs delete mode 100644 datafusion-flight-sql-table-provider/Cargo.toml delete mode 100644 datafusion-flight-sql-table-provider/src/lib.rs diff --git a/.release-plz.toml b/.release-plz.toml index 9927fe3..8b13789 100644 --- a/.release-plz.toml +++ b/.release-plz.toml @@ -1,7 +1 @@ -[[package]] -name = "datafusion-flight-sql-table-provider" -changelog_update = false -git_release_enable = false -publish = false -release = false -semver_check = false + diff --git a/Cargo.toml b/Cargo.toml index 4b2cf85..8177d26 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,11 +1,7 @@ [workspace] resolver = "2" -members = [ - "datafusion-federation", - "datafusion-flight-sql-server", - "datafusion-flight-sql-table-provider", -] +members = ["datafusion-federation"] [workspace.package] version = "0.3.5" @@ -15,19 +11,11 @@ readme = "README.md" repository = "https://github.com/datafusion-contrib/datafusion-federation" [workspace.dependencies] -arrow = "53.3" -arrow-flight = { version = "53.3", features = ["flight-sql-experimental"] } arrow-json = "53.3" async-stream = "0.3.5" async-trait = "0.1.83" datafusion = "44.0.0" datafusion-federation = { path = "./datafusion-federation", version = "0.3.5" } -datafusion-substrait = "44.0.0" futures = "0.3.31" tokio = { version = "1.41", features = ["full"] } -tonic = { version = "0.12", features = [ - "tls", - "transport", - "codegen", - "prost", -] } + diff --git a/datafusion-flight-sql-server/CHANGELOG.md b/datafusion-flight-sql-server/CHANGELOG.md deleted file mode 100644 index f83e363..0000000 --- a/datafusion-flight-sql-server/CHANGELOG.md +++ /dev/null @@ -1,38 +0,0 @@ -# Changelog - -All notable changes to this project will be documented in this file. - -The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), -and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - -## [Unreleased] - -## [0.4.6](https://github.com/datafusion-contrib/datafusion-federation/compare/datafusion-flight-sql-server-v0.4.5...datafusion-flight-sql-server-v0.4.6) - 2025-01-20 - -### Other - -- updated the following local packages: datafusion-federation - -## [0.4.5](https://github.com/datafusion-contrib/datafusion-federation/compare/datafusion-flight-sql-server-v0.4.4...datafusion-flight-sql-server-v0.4.5) - 2025-01-12 - -### Other - -- upgrade datafusion to 44 (#103) - -## [0.4.4](https://github.com/datafusion-contrib/datafusion-federation/compare/datafusion-flight-sql-server-v0.4.3...datafusion-flight-sql-server-v0.4.4) - 2025-01-04 - -### Other - -- updated the following local packages: datafusion-federation - -## [0.4.3](https://github.com/datafusion-contrib/datafusion-federation/compare/datafusion-flight-sql-server-v0.4.2...datafusion-flight-sql-server-v0.4.3) - 2024-12-23 - -### Other - -- add README.md (#91) - -## [0.4.2](https://github.com/datafusion-contrib/datafusion-federation/compare/datafusion-flight-sql-server-v0.4.1...datafusion-flight-sql-server-v0.4.2) - 2024-12-03 - -### Added - -- support prepared statement parameters ([#81](https://github.com/datafusion-contrib/datafusion-federation/pull/81)) diff --git a/datafusion-flight-sql-server/Cargo.toml b/datafusion-flight-sql-server/Cargo.toml deleted file mode 100644 index 913c17a..0000000 --- a/datafusion-flight-sql-server/Cargo.toml +++ /dev/null @@ -1,29 +0,0 @@ -[package] -name = "datafusion-flight-sql-server" -version = "0.4.6" -edition.workspace = true -license.workspace = true -repository.workspace = true -description = "Datafusion flight sql server." -readme = "README.md" - -[lib] -name = "datafusion_flight_sql_server" -path = "src/lib.rs" - -[dependencies] -arrow-flight.workspace = true -arrow.workspace = true -datafusion-federation = { workspace = true, features = ["sql"] } -datafusion-substrait = { workspace = true, features = ["protoc"] } -datafusion.workspace = true -futures.workspace = true -log = "0.4.22" -once_cell = "1.19.0" -prost = "0.13.1" -tonic.workspace = true -async-trait.workspace = true - -[dev-dependencies] -tokio.workspace = true -datafusion-flight-sql-table-provider = { path = "../datafusion-flight-sql-table-provider" } diff --git a/datafusion-flight-sql-server/README.md b/datafusion-flight-sql-server/README.md deleted file mode 100644 index f72903e..0000000 --- a/datafusion-flight-sql-server/README.md +++ /dev/null @@ -1,41 +0,0 @@ -# DataFusion Flight SQL Server - -The `datafusion-flight-sql-server` is a Flight SQL server that implements the -necessary endpoints to use DataFusion as the query engine. - -## Getting Started - -To use `datafusion-flight-sql-server` in your Rust project, run: - -```sh -$ cargo add datafusion-flight-sql-server -``` - -## Example - -Here's a basic example of setting up a Flight SQL server: - -```rust -use datafusion_flight_sql_server::service::FlightSqlService; -use datafusion::{ - execution::{ - context::SessionContext, - options::CsvReadOptions, - }, -}; - -async { - let dsn: String = "0.0.0.0:50051".to_string(); - let remote_ctx = SessionContext::new(); - remote_ctx - .register_csv("test", "./examples/test.csv", CsvReadOptions::new()) - .await.expect("Register csv"); - - FlightSqlService::new(remote_ctx.state()).serve(dsn.clone()) - .await - .expect("Run flight sql service"); - -}; -``` - -This example sets up a Flight SQL server listening on `127.0.0.1:50051`. diff --git a/datafusion-flight-sql-server/examples/flight-sql.rs b/datafusion-flight-sql-server/examples/flight-sql.rs deleted file mode 100644 index 021227a..0000000 --- a/datafusion-flight-sql-server/examples/flight-sql.rs +++ /dev/null @@ -1,84 +0,0 @@ -use std::{sync::Arc, time::Duration}; - -use arrow_flight::sql::client::FlightSqlServiceClient; -use datafusion::{ - catalog::SchemaProvider, - error::{DataFusionError, Result}, - execution::{ - context::{SessionContext, SessionState}, - options::CsvReadOptions, - }, -}; -use datafusion_federation::sql::{SQLFederationProvider, SQLSchemaProvider}; -use datafusion_flight_sql_server::service::FlightSqlService; -use datafusion_flight_sql_table_provider::FlightSQLExecutor; -use tokio::time::sleep; -use tonic::transport::Endpoint; - -#[tokio::main] -async fn main() -> Result<()> { - let dsn: String = "0.0.0.0:50051".to_string(); - let remote_ctx = SessionContext::new(); - remote_ctx - .register_csv("test", "./examples/test.csv", CsvReadOptions::new()) - .await?; - - // Remote context - tokio::spawn(async move { - FlightSqlService::new(remote_ctx.state()) - .serve(dsn.clone()) - .await - .unwrap(); - }); - - // Wait for server to run - sleep(Duration::from_secs(3)).await; - - // Local context - let state = datafusion_federation::default_session_state(); - let known_tables: Vec = ["test"].iter().map(|&x| x.into()).collect(); - - // Register schema - // TODO: table inference - let dsn: String = "http://localhost:50051".to_string(); - let client = new_client(dsn.clone()).await?; - let executor = Arc::new(FlightSQLExecutor::new(dsn, client)); - let provider = Arc::new(SQLFederationProvider::new(executor)); - let schema_provider = - Arc::new(SQLSchemaProvider::new_with_tables(provider, known_tables).await?); - overwrite_default_schema(&state, schema_provider)?; - - // Run query - let ctx = SessionContext::new_with_state(state); - let query = r#"SELECT * from test"#; - let df = ctx.sql(query).await?; - - // let explain = df.clone().explain(true, false)?; - // explain.show().await?; - - df.show().await -} - -fn overwrite_default_schema(state: &SessionState, schema: Arc) -> Result<()> { - let options = &state.config().options().catalog; - let catalog = state - .catalog_list() - .catalog(options.default_catalog.as_str()) - .unwrap(); - - catalog.register_schema(options.default_schema.as_str(), schema)?; - - Ok(()) -} - -/// Creates a new [FlightSqlServiceClient] for the passed endpoint. Completes the relevant auth configurations -/// or handshake as appropriate for the passed [FlightSQLAuth] variant. -async fn new_client(dsn: String) -> Result> { - let endpoint = Endpoint::new(dsn).map_err(tx_error_to_df)?; - let channel = endpoint.connect().await.map_err(tx_error_to_df)?; - Ok(FlightSqlServiceClient::new(channel)) -} - -fn tx_error_to_df(err: tonic::transport::Error) -> DataFusionError { - DataFusionError::External(format!("failed to connect: {err:?}").into()) -} diff --git a/datafusion-flight-sql-server/examples/test.csv b/datafusion-flight-sql-server/examples/test.csv deleted file mode 100644 index 811d276..0000000 --- a/datafusion-flight-sql-server/examples/test.csv +++ /dev/null @@ -1,4 +0,0 @@ -foo,bar -a,1 -b,2 -c,3 \ No newline at end of file diff --git a/datafusion-flight-sql-server/src/lib.rs b/datafusion-flight-sql-server/src/lib.rs deleted file mode 100644 index 101d335..0000000 --- a/datafusion-flight-sql-server/src/lib.rs +++ /dev/null @@ -1,5 +0,0 @@ -#![doc = include_str!("../README.md")] - -pub mod service; -pub mod session; -pub mod state; diff --git a/datafusion-flight-sql-server/src/service.rs b/datafusion-flight-sql-server/src/service.rs deleted file mode 100644 index cc97962..0000000 --- a/datafusion-flight-sql-server/src/service.rs +++ /dev/null @@ -1,1149 +0,0 @@ -use std::{collections::BTreeMap, pin::Pin, sync::Arc}; - -use arrow::{ - array::{ArrayRef, RecordBatch, StringArray}, - compute::concat_batches, - datatypes::{DataType, Field, SchemaBuilder, SchemaRef}, - error::ArrowError, - ipc::{ - reader::StreamReader, - writer::{IpcWriteOptions, StreamWriter}, - }, -}; -use arrow_flight::{ - decode::{DecodedPayload, FlightDataDecoder}, - sql::{ - self, - server::{FlightSqlService as ArrowFlightSqlService, PeekableFlightDataStream}, - ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest, - ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult, - ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest, - ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest, - ActionEndSavepointRequest, ActionEndTransactionRequest, Any, CommandGetCatalogs, - CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, - CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, - CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, - CommandPreparedStatementUpdate, CommandStatementQuery, CommandStatementSubstraitPlan, - CommandStatementUpdate, DoPutPreparedStatementResult, ProstMessageExt as _, SqlInfo, - TicketStatementQuery, - }, -}; -use arrow_flight::{ - encode::FlightDataEncoderBuilder, - error::FlightError, - flight_service_server::{FlightService, FlightServiceServer}, - Action, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse, - IpcMessage, SchemaAsIpc, Ticket, -}; -use datafusion::{ - common::{arrow::datatypes::Schema, ParamValues}, - dataframe::DataFrame, - datasource::TableType, - error::{DataFusionError, Result as DataFusionResult}, - execution::context::{SQLOptions, SessionContext, SessionState}, - logical_expr::LogicalPlan, - physical_plan::SendableRecordBatchStream, - scalar::ScalarValue, -}; -use datafusion_substrait::{ - logical_plan::consumer::from_substrait_plan, serializer::deserialize_bytes, -}; -use futures::{Stream, StreamExt, TryStreamExt}; -use log::info; -use once_cell::sync::Lazy; -use prost::bytes::Bytes; -use prost::Message; -use tonic::transport::Server; -use tonic::{Request, Response, Status, Streaming}; - -use super::session::{SessionStateProvider, StaticSessionStateProvider}; -use super::state::{CommandTicket, QueryHandle}; - -type Result = std::result::Result; - -/// FlightSqlService is a basic stateless FlightSqlService implementation. -pub struct FlightSqlService { - provider: Box, - sql_options: Option, -} - -impl FlightSqlService { - /// Creates a new FlightSqlService with a static SessionState. - pub fn new(state: SessionState) -> Self { - Self::new_with_provider(Box::new(StaticSessionStateProvider::new(state))) - } - - /// Creates a new FlightSqlService with a SessionStateProvider. - pub fn new_with_provider(provider: Box) -> Self { - Self { - provider, - sql_options: None, - } - } - - /// Replaces the sql_options with the provided options. - /// These options are used to verify all SQL queries. - /// When None the default [`SQLOptions`] are used. - pub fn with_sql_options(self, sql_options: SQLOptions) -> Self { - Self { - sql_options: Some(sql_options), - ..self - } - } - - // Federate substrait plans instead of SQL - // pub fn with_substrait() -> Self { - // TODO: Substrait federation - // } - - // Serves straightforward on the specified address. - pub async fn serve(self, addr: String) -> Result<(), Box> { - let addr = addr.parse()?; - info!("Listening on {addr:?}"); - - let svc = FlightServiceServer::new(self); - - Ok(Server::builder().add_service(svc).serve(addr).await?) - } - - async fn new_context( - &self, - request: Request, - ) -> Result<(Request, FlightSqlSessionContext)> { - let (metadata, extensions, msg) = request.into_parts(); - let inspect_request = Request::from_parts(metadata, extensions, ()); - - let state = self.provider.new_context(&inspect_request).await?; - let ctx = SessionContext::new_with_state(state); - - let (metadata, extensions, _) = inspect_request.into_parts(); - Ok(( - Request::from_parts(metadata, extensions, msg), - FlightSqlSessionContext { - inner: ctx, - sql_options: self.sql_options, - }, - )) - } -} - -/// The schema for GetTableTypes -static GET_TABLE_TYPES_SCHEMA: Lazy = Lazy::new(|| { - //TODO: Move this into arrow-flight itself, similar to the builder pattern for CommandGetCatalogs and CommandGetDbSchemas - Arc::new(Schema::new(vec![Field::new( - "table_type", - DataType::Utf8, - false, - )])) -}); - -struct FlightSqlSessionContext { - inner: SessionContext, - sql_options: Option, -} - -impl FlightSqlSessionContext { - async fn sql_to_logical_plan(&self, sql: &str) -> DataFusionResult { - let plan = self.inner.state().create_logical_plan(sql).await?; - let verifier = self.sql_options.unwrap_or_default(); - verifier.verify_plan(&plan)?; - Ok(plan) - } - - async fn execute_sql(&self, sql: &str) -> DataFusionResult { - let plan = self.sql_to_logical_plan(sql).await?; - self.execute_logical_plan(plan).await - } - - async fn execute_logical_plan( - &self, - plan: LogicalPlan, - ) -> DataFusionResult { - self.inner - .execute_logical_plan(plan) - .await? - .execute_stream() - .await - } -} - -#[tonic::async_trait] -impl ArrowFlightSqlService for FlightSqlService { - type FlightService = FlightSqlService; - - async fn do_handshake( - &self, - _request: Request>, - ) -> Result> + Send>>>> { - info!("do_handshake"); - // Favor middleware over handshake - // https://github.com/apache/arrow/issues/23836 - // https://github.com/apache/arrow/issues/25848 - Err(Status::unimplemented("handshake is not supported")) - } - - async fn do_get_fallback( - &self, - request: Request, - _message: Any, - ) -> Result::DoGetStream>> { - let (request, ctx) = self.new_context(request).await?; - - let ticket = CommandTicket::try_decode(request.into_inner().ticket) - .map_err(flight_error_to_status)?; - - match ticket.command { - sql::Command::CommandStatementQuery(CommandStatementQuery { query, .. }) => { - // print!("Query: {query}\n"); - - let stream = ctx.execute_sql(&query).await.map_err(df_error_to_status)?; - let arrow_schema = stream.schema(); - let arrow_stream = stream.map(|i| { - let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?; - Ok(batch) - }); - - let flight_data_stream = FlightDataEncoderBuilder::new() - .with_schema(arrow_schema) - .build(arrow_stream) - .map_err(flight_error_to_status) - .boxed(); - - Ok(Response::new(flight_data_stream)) - } - sql::Command::CommandPreparedStatementQuery(CommandPreparedStatementQuery { - prepared_statement_handle, - }) => { - let handle = QueryHandle::try_decode(prepared_statement_handle)?; - - let mut plan = ctx - .sql_to_logical_plan(handle.query()) - .await - .map_err(df_error_to_status)?; - - if let Some(param_values) = - decode_param_values(handle.parameters()).map_err(arrow_error_to_status)? - { - plan = plan - .with_param_values(param_values) - .map_err(df_error_to_status)?; - } - - let stream = ctx - .execute_logical_plan(plan) - .await - .map_err(df_error_to_status)?; - let arrow_schema = stream.schema(); - let arrow_stream = stream.map(|i| { - let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?; - Ok(batch) - }); - - let flight_data_stream = FlightDataEncoderBuilder::new() - .with_schema(arrow_schema) - .build(arrow_stream) - .map_err(flight_error_to_status) - .boxed(); - - Ok(Response::new(flight_data_stream)) - } - sql::Command::CommandStatementSubstraitPlan(CommandStatementSubstraitPlan { - plan, - .. - }) => { - let substrait_bytes = &plan - .ok_or(Status::invalid_argument( - "Expected substrait plan, found None", - ))? - .plan; - - let plan = parse_substrait_bytes(&ctx, substrait_bytes).await?; - - let state = ctx.inner.state(); - let df = DataFrame::new(state, plan); - - let stream = df.execute_stream().await.map_err(df_error_to_status)?; - let arrow_schema = stream.schema(); - let arrow_stream = stream.map(|i| { - let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?; - Ok(batch) - }); - - let flight_data_stream = FlightDataEncoderBuilder::new() - .with_schema(arrow_schema) - .build(arrow_stream) - .map_err(flight_error_to_status) - .boxed(); - - Ok(Response::new(flight_data_stream)) - } - _ => { - return Err(Status::internal(format!( - "statement handle not found: {:?}", - ticket.command - ))); - } - } - } - - async fn get_flight_info_statement( - &self, - query: CommandStatementQuery, - request: Request, - ) -> Result> { - let (request, ctx) = self.new_context(request).await?; - - let sql = &query.query; - info!("get_flight_info_statement with query={sql}"); - - let flight_descriptor = request.into_inner(); - - let plan = ctx - .sql_to_logical_plan(sql) - .await - .map_err(df_error_to_status)?; - - let dataset_schema = get_schema_for_plan(&plan); - - // Form the response ticket (that the client will pass back to DoGet) - let ticket = CommandTicket::new(sql::Command::CommandStatementQuery(query)) - .try_encode() - .map_err(flight_error_to_status)?; - - let endpoint = FlightEndpoint::new().with_ticket(Ticket { ticket }); - - let flight_info = FlightInfo::new() - .with_endpoint(endpoint) - // return descriptor we were passed - .with_descriptor(flight_descriptor) - .try_with_schema(dataset_schema.as_ref()) - .map_err(arrow_error_to_status)?; - - Ok(Response::new(flight_info)) - } - - async fn get_flight_info_substrait_plan( - &self, - query: CommandStatementSubstraitPlan, - request: Request, - ) -> Result> { - info!("get_flight_info_substrait_plan"); - let (request, ctx) = self.new_context(request).await?; - - let substrait_bytes = &query - .plan - .as_ref() - .ok_or(Status::invalid_argument( - "Expected substrait plan, found None", - ))? - .plan; - - let plan = parse_substrait_bytes(&ctx, substrait_bytes).await?; - - let flight_descriptor = request.into_inner(); - - let dataset_schema = get_schema_for_plan(&plan); - - // Form the response ticket (that the client will pass back to DoGet) - let ticket = CommandTicket::new(sql::Command::CommandStatementSubstraitPlan(query)) - .try_encode() - .map_err(flight_error_to_status)?; - - let endpoint = FlightEndpoint::new().with_ticket(Ticket { ticket }); - - let flight_info = FlightInfo::new() - .with_endpoint(endpoint) - // return descriptor we were passed - .with_descriptor(flight_descriptor) - .try_with_schema(dataset_schema.as_ref()) - .map_err(arrow_error_to_status)?; - - Ok(Response::new(flight_info)) - } - - async fn get_flight_info_prepared_statement( - &self, - cmd: CommandPreparedStatementQuery, - request: Request, - ) -> Result> { - let (request, ctx) = self.new_context(request).await?; - - let handle = QueryHandle::try_decode(cmd.prepared_statement_handle.clone()) - .map_err(|e| Status::internal(format!("Error decoding handle: {e}")))?; - - info!("get_flight_info_prepared_statement with handle={handle}"); - - let flight_descriptor = request.into_inner(); - - let sql = handle.query(); - let plan = ctx - .sql_to_logical_plan(sql) - .await - .map_err(df_error_to_status)?; - - let dataset_schema = get_schema_for_plan(&plan); - - // Form the response ticket (that the client will pass back to DoGet) - let ticket = CommandTicket::new(sql::Command::CommandPreparedStatementQuery(cmd)) - .try_encode() - .map_err(flight_error_to_status)?; - - let endpoint = FlightEndpoint::new().with_ticket(Ticket { ticket }); - - let flight_info = FlightInfo::new() - .with_endpoint(endpoint) - // return descriptor we were passed - .with_descriptor(flight_descriptor) - .try_with_schema(dataset_schema.as_ref()) - .map_err(arrow_error_to_status)?; - - Ok(Response::new(flight_info)) - } - - async fn get_flight_info_catalogs( - &self, - query: CommandGetCatalogs, - request: Request, - ) -> Result> { - info!("get_flight_info_catalogs"); - let (request, _ctx) = self.new_context(request).await?; - - let flight_descriptor = request.into_inner(); - let ticket = Ticket { - ticket: query.as_any().encode_to_vec().into(), - }; - let endpoint = FlightEndpoint::new().with_ticket(ticket); - - let flight_info = FlightInfo::new() - .try_with_schema(&query.into_builder().schema()) - .map_err(arrow_error_to_status)? - .with_endpoint(endpoint) - .with_descriptor(flight_descriptor); - - Ok(Response::new(flight_info)) - } - - async fn get_flight_info_schemas( - &self, - query: CommandGetDbSchemas, - request: Request, - ) -> Result> { - info!("get_flight_info_schemas"); - let (request, _ctx) = self.new_context(request).await?; - let flight_descriptor = request.into_inner(); - let ticket = Ticket { - ticket: query.as_any().encode_to_vec().into(), - }; - let endpoint = FlightEndpoint::new().with_ticket(ticket); - - let flight_info = FlightInfo::new() - .try_with_schema(&query.into_builder().schema()) - .map_err(arrow_error_to_status)? - .with_endpoint(endpoint) - .with_descriptor(flight_descriptor); - - Ok(Response::new(flight_info)) - } - - async fn get_flight_info_tables( - &self, - query: CommandGetTables, - request: Request, - ) -> Result> { - info!("get_flight_info_tables"); - let (request, _ctx) = self.new_context(request).await?; - - let flight_descriptor = request.into_inner(); - let ticket = Ticket { - ticket: query.as_any().encode_to_vec().into(), - }; - let endpoint = FlightEndpoint::new().with_ticket(ticket); - - let flight_info = FlightInfo::new() - .try_with_schema(&query.into_builder().schema()) - .map_err(arrow_error_to_status)? - .with_endpoint(endpoint) - .with_descriptor(flight_descriptor); - - Ok(Response::new(flight_info)) - } - - async fn get_flight_info_table_types( - &self, - query: CommandGetTableTypes, - request: Request, - ) -> Result> { - info!("get_flight_info_table_types"); - let (request, _ctx) = self.new_context(request).await?; - - let flight_descriptor = request.into_inner(); - let ticket = Ticket { - ticket: query.as_any().encode_to_vec().into(), - }; - let endpoint = FlightEndpoint::new().with_ticket(ticket); - - let flight_info = FlightInfo::new() - .try_with_schema(&GET_TABLE_TYPES_SCHEMA) - .map_err(arrow_error_to_status)? - .with_endpoint(endpoint) - .with_descriptor(flight_descriptor); - - Ok(Response::new(flight_info)) - } - - async fn get_flight_info_sql_info( - &self, - _query: CommandGetSqlInfo, - request: Request, - ) -> Result> { - info!("get_flight_info_sql_info"); - let (_, _) = self.new_context(request).await?; - - Err(Status::unimplemented("Implement CommandGetSqlInfo")) - } - - async fn get_flight_info_primary_keys( - &self, - _query: CommandGetPrimaryKeys, - request: Request, - ) -> Result> { - info!("get_flight_info_primary_keys"); - let (_, _) = self.new_context(request).await?; - - Err(Status::unimplemented( - "Implement get_flight_info_primary_keys", - )) - } - - async fn get_flight_info_exported_keys( - &self, - _query: CommandGetExportedKeys, - request: Request, - ) -> Result> { - info!("get_flight_info_exported_keys"); - let (_, _) = self.new_context(request).await?; - - Err(Status::unimplemented( - "Implement get_flight_info_exported_keys", - )) - } - - async fn get_flight_info_imported_keys( - &self, - _query: CommandGetImportedKeys, - request: Request, - ) -> Result> { - info!("get_flight_info_imported_keys"); - let (_, _) = self.new_context(request).await?; - - Err(Status::unimplemented( - "Implement get_flight_info_imported_keys", - )) - } - - async fn get_flight_info_cross_reference( - &self, - _query: CommandGetCrossReference, - request: Request, - ) -> Result> { - info!("get_flight_info_cross_reference"); - let (_, _) = self.new_context(request).await?; - - Err(Status::unimplemented( - "Implement get_flight_info_cross_reference", - )) - } - - async fn get_flight_info_xdbc_type_info( - &self, - _query: CommandGetXdbcTypeInfo, - request: Request, - ) -> Result> { - info!("get_flight_info_xdbc_type_info"); - let (_, _) = self.new_context(request).await?; - - Err(Status::unimplemented( - "Implement get_flight_info_xdbc_type_info", - )) - } - - async fn do_get_statement( - &self, - _ticket: TicketStatementQuery, - request: Request, - ) -> Result::DoGetStream>> { - info!("do_get_statement"); - let (_, _) = self.new_context(request).await?; - - Err(Status::unimplemented("Implement do_get_statement")) - } - - async fn do_get_prepared_statement( - &self, - _query: CommandPreparedStatementQuery, - request: Request, - ) -> Result::DoGetStream>> { - info!("do_get_prepared_statement"); - let (_, _) = self.new_context(request).await?; - - Err(Status::unimplemented("Implement do_get_prepared_statement")) - } - - async fn do_get_catalogs( - &self, - query: CommandGetCatalogs, - request: Request, - ) -> Result::DoGetStream>> { - info!("do_get_catalogs"); - let (_request, ctx) = self.new_context(request).await?; - let catalog_names = ctx.inner.catalog_names(); - - let mut builder = query.into_builder(); - for catalog_name in &catalog_names { - builder.append(catalog_name); - } - let schema = builder.schema(); - let batch = builder.build(); - let stream = FlightDataEncoderBuilder::new() - .with_schema(schema) - .build(futures::stream::once(async { batch })) - .map_err(Status::from); - Ok(Response::new(Box::pin(stream))) - } - - async fn do_get_schemas( - &self, - query: CommandGetDbSchemas, - request: Request, - ) -> Result::DoGetStream>> { - info!("do_get_schemas"); - let (_request, ctx) = self.new_context(request).await?; - let catalog_name = query.catalog.clone(); - // Append all schemas to builder, the builder handles applying the filters. - let mut builder = query.into_builder(); - if let Some(catalog_name) = &catalog_name { - if let Some(catalog) = ctx.inner.catalog(catalog_name) { - for schema_name in &catalog.schema_names() { - builder.append(catalog_name, schema_name); - } - } - }; - - let schema = builder.schema(); - let batch = builder.build(); - let stream = FlightDataEncoderBuilder::new() - .with_schema(schema) - .build(futures::stream::once(async { batch })) - .map_err(Status::from); - Ok(Response::new(Box::pin(stream))) - } - - async fn do_get_tables( - &self, - query: CommandGetTables, - request: Request, - ) -> Result::DoGetStream>> { - info!("do_get_tables"); - let (_request, ctx) = self.new_context(request).await?; - let catalog_name = query.catalog.clone(); - let mut builder = query.into_builder(); - // Append all schemas/tables to builder, the builder handles applying the filters. - if let Some(catalog_name) = &catalog_name { - if let Some(catalog) = ctx.inner.catalog(catalog_name) { - for schema_name in &catalog.schema_names() { - if let Some(schema) = catalog.schema(schema_name) { - for table_name in &schema.table_names() { - if let Some(table) = - schema.table(table_name).await.map_err(df_error_to_status)? - { - builder - .append( - catalog_name, - schema_name, - table_name, - table.table_type().to_string(), - &table.schema(), - ) - .map_err(flight_error_to_status)?; - } - } - } - } - } - }; - - let schema = builder.schema(); - let batch = builder.build(); - let stream = FlightDataEncoderBuilder::new() - .with_schema(schema) - .build(futures::stream::once(async { batch })) - .map_err(Status::from); - Ok(Response::new(Box::pin(stream))) - } - - async fn do_get_table_types( - &self, - _query: CommandGetTableTypes, - request: Request, - ) -> Result::DoGetStream>> { - info!("do_get_table_types"); - let (_, _) = self.new_context(request).await?; - - // Report all variants of table types that datafusion uses. - let table_types: ArrayRef = Arc::new(StringArray::from( - vec![TableType::Base, TableType::View, TableType::Temporary] - .into_iter() - .map(|tt| tt.to_string()) - .collect::>(), - )); - - let batch = RecordBatch::try_from_iter(vec![("table_type", table_types)]).unwrap(); - - let stream = FlightDataEncoderBuilder::new() - .with_schema(GET_TABLE_TYPES_SCHEMA.clone()) - .build(futures::stream::once(async { Ok(batch) })) - .map_err(Status::from); - Ok(Response::new(Box::pin(stream))) - } - - async fn do_get_sql_info( - &self, - _query: CommandGetSqlInfo, - request: Request, - ) -> Result::DoGetStream>> { - info!("do_get_sql_info"); - let (_, _) = self.new_context(request).await?; - - Err(Status::unimplemented("Implement do_get_sql_info")) - } - - async fn do_get_primary_keys( - &self, - _query: CommandGetPrimaryKeys, - request: Request, - ) -> Result::DoGetStream>> { - info!("do_get_primary_keys"); - let (_, _) = self.new_context(request).await?; - - Err(Status::unimplemented("Implement do_get_primary_keys")) - } - - async fn do_get_exported_keys( - &self, - _query: CommandGetExportedKeys, - request: Request, - ) -> Result::DoGetStream>> { - info!("do_get_exported_keys"); - let (_, _) = self.new_context(request).await?; - - Err(Status::unimplemented("Implement do_get_exported_keys")) - } - - async fn do_get_imported_keys( - &self, - _query: CommandGetImportedKeys, - request: Request, - ) -> Result::DoGetStream>> { - info!("do_get_imported_keys"); - let (_, _) = self.new_context(request).await?; - - Err(Status::unimplemented("Implement do_get_imported_keys")) - } - - async fn do_get_cross_reference( - &self, - _query: CommandGetCrossReference, - request: Request, - ) -> Result::DoGetStream>> { - info!("do_get_cross_reference"); - let (_, _) = self.new_context(request).await?; - - Err(Status::unimplemented("Implement do_get_cross_reference")) - } - - async fn do_get_xdbc_type_info( - &self, - _query: CommandGetXdbcTypeInfo, - request: Request, - ) -> Result::DoGetStream>> { - info!("do_get_xdbc_type_info"); - let (_, _) = self.new_context(request).await?; - - Err(Status::unimplemented("Implement do_get_xdbc_type_info")) - } - - async fn do_put_statement_update( - &self, - _ticket: CommandStatementUpdate, - request: Request, - ) -> Result { - info!("do_put_statement_update"); - let (_, _) = self.new_context(request).await?; - - Err(Status::unimplemented("Implement do_put_statement_update")) - } - - async fn do_put_prepared_statement_query( - &self, - query: CommandPreparedStatementQuery, - request: Request, - ) -> Result { - info!("do_put_prepared_statement_query"); - let (request, _) = self.new_context(request).await?; - - let mut handle = QueryHandle::try_decode(query.prepared_statement_handle)?; - - info!( - "do_action_create_prepared_statement query={:?}", - handle.query() - ); - // Collect request flight data as parameters - // Decode and encode as a single ipc stream - let mut decoder = - FlightDataDecoder::new(request.into_inner().map_err(status_to_flight_error)); - let schema = decode_schema(&mut decoder).await?; - let mut parameters = Vec::new(); - let mut encoder = - StreamWriter::try_new(&mut parameters, &schema).map_err(arrow_error_to_status)?; - let mut total_rows = 0; - while let Some(msg) = decoder.try_next().await? { - match msg.payload { - DecodedPayload::None => {} - DecodedPayload::Schema(_) => { - return Err(Status::invalid_argument( - "parameter flight data must contain a single schema", - )); - } - DecodedPayload::RecordBatch(record_batch) => { - total_rows += record_batch.num_rows(); - encoder - .write(&record_batch) - .map_err(arrow_error_to_status)?; - } - } - } - if total_rows > 1 { - return Err(Status::invalid_argument( - "parameters should contain a single row", - )); - } - - handle.set_parameters(Some(parameters.into())); - - let res = DoPutPreparedStatementResult { - prepared_statement_handle: Some(Bytes::from(handle)), - }; - - Ok(res) - } - - async fn do_put_prepared_statement_update( - &self, - _handle: CommandPreparedStatementUpdate, - request: Request, - ) -> Result { - info!("do_put_prepared_statement_update"); - let (_, _) = self.new_context(request).await?; - - // statements like "CREATE TABLE.." or "SET datafusion.nnn.." call this function - // and we are required to return some row count here - Ok(-1) - } - - async fn do_put_substrait_plan( - &self, - _query: CommandStatementSubstraitPlan, - request: Request, - ) -> Result { - info!("do_put_prepared_statement_update"); - let (_, _) = self.new_context(request).await?; - - Err(Status::unimplemented( - "Implement do_put_prepared_statement_update", - )) - } - - async fn do_action_create_prepared_statement( - &self, - query: ActionCreatePreparedStatementRequest, - request: Request, - ) -> Result { - let (_, ctx) = self.new_context(request).await?; - - let sql = query.query.clone(); - info!( - "do_action_create_prepared_statement query={:?}", - query.query - ); - - let plan = ctx - .sql_to_logical_plan(sql.as_str()) - .await - .map_err(df_error_to_status)?; - - let dataset_schema = get_schema_for_plan(&plan); - let parameter_schema = parameter_schema_for_plan(&plan)?; - - let dataset_schema = - encode_schema(dataset_schema.as_ref()).map_err(arrow_error_to_status)?; - let parameter_schema = - encode_schema(parameter_schema.as_ref()).map_err(arrow_error_to_status)?; - - let handle = QueryHandle::new(sql, None); - - let res = ActionCreatePreparedStatementResult { - prepared_statement_handle: Bytes::from(handle), - dataset_schema, - parameter_schema, - }; - - Ok(res) - } - - async fn do_action_close_prepared_statement( - &self, - query: ActionClosePreparedStatementRequest, - request: Request, - ) -> Result<(), Status> { - let (_, _) = self.new_context(request).await?; - - let handle = query.prepared_statement_handle.as_ref(); - if let Ok(handle) = std::str::from_utf8(handle) { - info!( - "do_action_close_prepared_statement with handle {:?}", - handle - ); - - // NOP since stateless - } - Ok(()) - } - - async fn do_action_create_prepared_substrait_plan( - &self, - _query: ActionCreatePreparedSubstraitPlanRequest, - request: Request, - ) -> Result { - info!("do_action_create_prepared_substrait_plan"); - let (_, _) = self.new_context(request).await?; - - Err(Status::unimplemented( - "Implement do_action_create_prepared_substrait_plan", - )) - } - - async fn do_action_begin_transaction( - &self, - _query: ActionBeginTransactionRequest, - request: Request, - ) -> Result { - let (_, _) = self.new_context(request).await?; - - info!("do_action_begin_transaction"); - Err(Status::unimplemented( - "Implement do_action_begin_transaction", - )) - } - - async fn do_action_end_transaction( - &self, - _query: ActionEndTransactionRequest, - request: Request, - ) -> Result<(), Status> { - info!("do_action_end_transaction"); - let (_, _) = self.new_context(request).await?; - - Err(Status::unimplemented("Implement do_action_end_transaction")) - } - - async fn do_action_begin_savepoint( - &self, - _query: ActionBeginSavepointRequest, - request: Request, - ) -> Result { - info!("do_action_begin_savepoint"); - let (_, _) = self.new_context(request).await?; - - Err(Status::unimplemented("Implement do_action_begin_savepoint")) - } - - async fn do_action_end_savepoint( - &self, - _query: ActionEndSavepointRequest, - request: Request, - ) -> Result<(), Status> { - info!("do_action_end_savepoint"); - let (_, _) = self.new_context(request).await?; - - Err(Status::unimplemented("Implement do_action_end_savepoint")) - } - - async fn do_action_cancel_query( - &self, - _query: ActionCancelQueryRequest, - request: Request, - ) -> Result { - info!("do_action_cancel_query"); - let (_, _) = self.new_context(request).await?; - - Err(Status::unimplemented("Implement do_action_cancel_query")) - } - - async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {} -} - -/// Takes a substrait plan serialized as [Bytes] and deserializes this to -/// a Datafusion [LogicalPlan] -async fn parse_substrait_bytes( - ctx: &FlightSqlSessionContext, - substrait: &Bytes, -) -> Result { - let substrait_plan = deserialize_bytes(substrait.to_vec()) - .await - .map_err(df_error_to_status)?; - - from_substrait_plan(&ctx.inner.state(), &substrait_plan) - .await - .map_err(df_error_to_status) -} - -/// Encodes the schema IPC encoded (schema_bytes) -fn encode_schema(schema: &Schema) -> std::result::Result { - let options = IpcWriteOptions::default(); - - // encode the schema into the correct form - let message: Result = SchemaAsIpc::new(schema, &options).try_into(); - - let IpcMessage(schema) = message?; - - Ok(schema) -} - -/// Return the schema for the specified logical plan -fn get_schema_for_plan(logical_plan: &LogicalPlan) -> SchemaRef { - // gather real schema, but only - let schema = Schema::from(logical_plan.schema().as_ref()).into(); - - // Use an empty FlightDataEncoder to determine the schema of the encoded flight data. - // This is necessary as the schema can change based on dictionary hydration behavior. - let flight_data_stream = FlightDataEncoderBuilder::new() - // Inform the builder of the input stream schema - .with_schema(schema) - .build(futures::stream::iter([])); - - // Retrieve the schema of the encoded data - flight_data_stream - .known_schema() - .expect("flight data schema should be known when explicitly provided via `with_schema`") -} - -fn parameter_schema_for_plan(plan: &LogicalPlan) -> Result { - let parameters = plan - .get_parameter_types() - .map_err(df_error_to_status)? - .into_iter() - .map(|(name, dt)| { - dt.map(|dt| (name.clone(), dt)).ok_or_else(|| { - Status::internal(format!( - "unable to determine type of query parameter {name}" - )) - }) - }) - // Collect into BTreeMap so we get a consistent order of the parameters - .collect::, Status>>()?; - - let mut builder = SchemaBuilder::new(); - parameters - .into_iter() - .for_each(|(name, typ)| builder.push(Field::new(name, typ, false))); - Ok(builder.finish().into()) -} - -fn arrow_error_to_status(err: ArrowError) -> Status { - Status::internal(format!("{err:?}")) -} - -fn flight_error_to_status(err: FlightError) -> Status { - Status::internal(format!("{err:?}")) -} - -fn df_error_to_status(err: DataFusionError) -> Status { - Status::internal(format!("{err:?}")) -} - -fn status_to_flight_error(status: Status) -> FlightError { - FlightError::Tonic(status) -} - -async fn decode_schema(decoder: &mut FlightDataDecoder) -> Result { - while let Some(msg) = decoder.try_next().await? { - match msg.payload { - DecodedPayload::None => {} - DecodedPayload::Schema(schema) => { - return Ok(schema); - } - DecodedPayload::RecordBatch(_) => { - return Err(Status::invalid_argument( - "parameter flight data must have a known schema", - )); - } - } - } - - Err(Status::invalid_argument( - "parameter flight data must have a schema", - )) -} - -// Decode parameter ipc stream as ParamValues -fn decode_param_values( - parameters: Option<&[u8]>, -) -> Result, arrow::error::ArrowError> { - parameters - .map(|parameters| { - let decoder = StreamReader::try_new(parameters, None)?; - let schema = decoder.schema(); - let batches = decoder.into_iter().collect::, _>>()?; - let batch = concat_batches(&schema, batches.iter())?; - Ok(record_to_param_values(&batch)?) - }) - .transpose() -} - -// Converts a record batch with a single row into ParamValues -fn record_to_param_values(batch: &RecordBatch) -> Result { - let mut param_values: Vec<(String, Option, ScalarValue)> = Vec::new(); - - let mut is_list = true; - for col_index in 0..batch.num_columns() { - let array = batch.column(col_index); - let scalar = ScalarValue::try_from_array(array, 0)?; - let name = batch - .schema_ref() - .field(col_index) - .name() - .trim_start_matches('$') - .to_string(); - let index = name.parse().ok(); - is_list &= index.is_some(); - param_values.push((name, index, scalar)); - } - if is_list { - let mut values: Vec<(Option, ScalarValue)> = param_values - .into_iter() - .map(|(_name, index, value)| (index, value)) - .collect(); - values.sort_by_key(|(index, _value)| *index); - Ok(values - .into_iter() - .map(|(_index, value)| value) - .collect::>() - .into()) - } else { - Ok(param_values - .into_iter() - .map(|(name, _index, value)| (name, value)) - .collect::>() - .into()) - } -} diff --git a/datafusion-flight-sql-server/src/session.rs b/datafusion-flight-sql-server/src/session.rs deleted file mode 100644 index 8b10df7..0000000 --- a/datafusion-flight-sql-server/src/session.rs +++ /dev/null @@ -1,31 +0,0 @@ -use async_trait::async_trait; -use datafusion::execution::context::SessionState; -use tonic::{Request, Status}; - -type Result = std::result::Result; - -// SessionStateProvider is a trait used to provide a SessionState for a given -// request. -#[async_trait] -pub trait SessionStateProvider: Sync + Send { - async fn new_context(&self, request: &Request<()>) -> Result; -} - -// StaticSessionStateProvider is a simple implementation of SessionStateProvider that -// uses a static SessionState. -pub(crate) struct StaticSessionStateProvider { - state: SessionState, -} - -impl StaticSessionStateProvider { - pub fn new(state: SessionState) -> Self { - Self { state } - } -} - -#[async_trait] -impl SessionStateProvider for StaticSessionStateProvider { - async fn new_context(&self, _request: &Request<()>) -> Result { - Ok(self.state.clone()) - } -} diff --git a/datafusion-flight-sql-server/src/state.rs b/datafusion-flight-sql-server/src/state.rs deleted file mode 100644 index 7b17051..0000000 --- a/datafusion-flight-sql-server/src/state.rs +++ /dev/null @@ -1,120 +0,0 @@ -use std::fmt::Display; - -use arrow_flight::{ - error::FlightError, - sql::{self, Any, Command}, -}; -use prost::{bytes::Bytes, Message}; - -pub type Result = std::result::Result; - -#[derive(Debug, PartialEq, Clone)] -pub struct CommandTicket { - pub command: sql::Command, -} - -impl CommandTicket { - pub fn new(cmd: sql::Command) -> Self { - Self { command: cmd } - } - - pub fn try_decode(msg: Bytes) -> Result { - let msg = CommandTicketMessage::decode(msg).map_err(decode_error_flight_error)?; - - Self::try_decode_command(msg.command) - } - - pub fn try_decode_command(cmd: Bytes) -> Result { - let content_msg = Any::decode(cmd).map_err(decode_error_flight_error)?; - let command = Command::try_from(content_msg).map_err(FlightError::Arrow)?; - - Ok(Self { command }) - } - - pub fn try_encode(self) -> Result { - let content_msg = self.command.into_any().encode_to_vec(); - - let msg = CommandTicketMessage { - command: content_msg.into(), - }; - - Ok(msg.encode_to_vec().into()) - } -} - -#[derive(Clone, PartialEq, Message)] -struct CommandTicketMessage { - #[prost(bytes = "bytes", tag = "2")] - command: Bytes, -} - -fn decode_error_flight_error(err: prost::DecodeError) -> FlightError { - FlightError::DecodeError(format!("{err:?}")) -} - -/// Represents a query handle for use in prepared statements. -/// All state required to run the prepared statement is passed -/// back and forth to the client, so any service instance can run it -#[derive(Debug, Clone)] -pub struct QueryHandle { - /// The raw SQL query text - query: String, - parameters: Option, -} - -impl QueryHandle { - pub fn new(query: String, parameters: Option) -> Self { - Self { query, parameters } - } - - pub fn query(&self) -> &str { - self.query.as_ref() - } - - pub fn parameters(&self) -> Option<&[u8]> { - self.parameters.as_deref() - } - - pub fn set_parameters(&mut self, parameters: Option) { - self.parameters = parameters; - } - - pub fn try_decode(msg: Bytes) -> Result { - let msg = QueryHandleMessage::decode(msg).map_err(decode_error_flight_error)?; - - Ok(Self { - query: msg.query, - parameters: msg.parameters, - }) - } - - pub fn encode(self) -> Bytes { - let msg = QueryHandleMessage { - query: self.query, - parameters: self.parameters, - }; - - msg.encode_to_vec().into() - } -} - -impl From for Bytes { - fn from(value: QueryHandle) -> Self { - value.encode() - } -} - -impl Display for QueryHandle { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Query({})", self.query) - } -} - -#[derive(Clone, PartialEq, Message)] -pub struct QueryHandleMessage { - /// The raw SQL query text - #[prost(string, tag = "1")] - query: String, - #[prost(bytes = "bytes", optional, tag = "2")] - parameters: Option, -} diff --git a/datafusion-flight-sql-table-provider/Cargo.toml b/datafusion-flight-sql-table-provider/Cargo.toml deleted file mode 100644 index 79a1311..0000000 --- a/datafusion-flight-sql-table-provider/Cargo.toml +++ /dev/null @@ -1,15 +0,0 @@ -[package] -name = "datafusion-flight-sql-table-provider" -version.workspace = true -edition.workspace = true -license.workspace = true -readme.workspace = true - -[dependencies] -arrow-flight.workspace = true -arrow.workspace = true -async-trait.workspace = true -datafusion-federation = { workspace = true, features = ["sql"] } -datafusion.workspace = true -futures.workspace = true -tonic.workspace = true diff --git a/datafusion-flight-sql-table-provider/src/lib.rs b/datafusion-flight-sql-table-provider/src/lib.rs deleted file mode 100644 index 0bcf432..0000000 --- a/datafusion-flight-sql-table-provider/src/lib.rs +++ /dev/null @@ -1,105 +0,0 @@ -use std::sync::Arc; - -use arrow::{datatypes::SchemaRef, error::ArrowError}; -use arrow_flight::sql::client::FlightSqlServiceClient; -use async_trait::async_trait; -use datafusion::{ - error::{DataFusionError, Result}, - physical_plan::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream}, - sql::unparser::dialect::{DefaultDialect, Dialect}, -}; -use datafusion_federation::sql::SQLExecutor; -use futures::TryStreamExt; -use tonic::transport::Channel; - -pub struct FlightSQLExecutor { - context: String, - client: FlightSqlServiceClient, -} - -impl FlightSQLExecutor { - pub fn new(dsn: String, client: FlightSqlServiceClient) -> Self { - Self { - context: dsn, - client, - } - } - - pub fn context(&mut self, context: String) { - self.context = context; - } -} - -async fn make_flight_sql_stream( - sql: String, - mut client: FlightSqlServiceClient, - schema: SchemaRef, -) -> Result { - let flight_info = client - .execute(sql.to_string(), None) - .await - .map_err(arrow_error_to_df)?; - - let mut flight_data_streams = Vec::with_capacity(flight_info.endpoint.len()); - for endpoint in flight_info.endpoint { - let ticket = endpoint.ticket.ok_or(DataFusionError::Execution( - "FlightEndpoint missing ticket!".to_string(), - ))?; - let flight_data = client.do_get(ticket).await?; - flight_data_streams.push(flight_data); - } - - let record_batch_stream = futures::stream::select_all(flight_data_streams) - .map_err(|e| DataFusionError::External(Box::new(e))); - - Ok(Box::pin(RecordBatchStreamAdapter::new( - schema, - record_batch_stream, - ))) -} - -#[async_trait] -impl SQLExecutor for FlightSQLExecutor { - fn name(&self) -> &str { - "flight_sql_executor" - } - fn compute_context(&self) -> Option { - Some(self.context.clone()) - } - fn execute(&self, sql: &str, schema: SchemaRef) -> Result { - let future_stream = - make_flight_sql_stream(sql.to_string(), self.client.clone(), Arc::clone(&schema)); - let stream = futures::stream::once(future_stream).try_flatten(); - - Ok(Box::pin(RecordBatchStreamAdapter::new( - Arc::clone(&schema), - stream, - ))) - } - - async fn table_names(&self) -> Result> { - Err(DataFusionError::NotImplemented( - "flight_sql source: table inference not implemented".to_string(), - )) - } - - async fn get_table_schema(&self, table_name: &str) -> Result { - let sql = format!("select * from {table_name} limit 1"); - let flight_info = self - .client - .clone() - .execute(sql, None) - .await - .map_err(arrow_error_to_df)?; - let schema = flight_info.try_decode_schema().map_err(arrow_error_to_df)?; - Ok(Arc::new(schema)) - } - - fn dialect(&self) -> Arc { - Arc::new(DefaultDialect {}) - } -} - -fn arrow_error_to_df(err: ArrowError) -> DataFusionError { - DataFusionError::External(format!("arrow error: {err:?}").into()) -}