From 724b36d6a21007c0334405de0a07bd1c6b45edda Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Fri, 23 Sep 2022 14:10:31 -0600 Subject: [PATCH 1/2] Store sessions so users can register tables and query them through flight --- ballista/rust/core/Cargo.toml | 2 +- ballista/rust/core/proto/ballista.proto | 2 + ballista/rust/core/src/client.rs | 4 + .../src/execution_plans/distributed_query.rs | 11 +- .../src/execution_plans/shuffle_reader.rs | 11 +- ballista/rust/core/src/serde/mod.rs | 14 + .../core/src/serde/scheduler/from_proto.rs | 2 + ballista/rust/core/src/serde/scheduler/mod.rs | 2 + .../rust/core/src/serde/scheduler/to_proto.rs | 4 + ballista/rust/executor/src/flight_service.rs | 22 +- ballista/rust/scheduler/src/flight_sql.rs | 243 ++++++++++++++++-- 11 files changed, 280 insertions(+), 37 deletions(-) diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml index 4561c1e35..2f0d167c2 100644 --- a/ballista/rust/core/Cargo.toml +++ b/ballista/rust/core/Cargo.toml @@ -38,7 +38,7 @@ simd = ["datafusion/simd"] [dependencies] ahash = { version = "0.8", default-features = false } -arrow-flight = { version = "22.0.0" } +arrow-flight = { version = "22.0.0", features = ["flight-sql-experimental"] } async-trait = "0.1.41" chrono = { version = "0.4", default-features = false } clap = { version = "3", features = ["derive", "cargo"] } diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index d998cc7b0..a2b5f1fd8 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -517,6 +517,8 @@ message FetchPartition { uint32 stage_id = 2; uint32 partition_id = 3; string path = 4; + string host = 5; + uint32 port = 6; } // Mapping from partition id to executor id diff --git a/ballista/rust/core/src/client.rs b/ballista/rust/core/src/client.rs index 61c19c643..13276f1b2 100644 --- a/ballista/rust/core/src/client.rs +++ b/ballista/rust/core/src/client.rs @@ -80,12 +80,16 @@ impl BallistaClient { stage_id: usize, partition_id: usize, path: &str, + host: &str, + port: u16, ) -> Result { let action = Action::FetchPartition { job_id: job_id.to_string(), stage_id, partition_id, path: path.to_owned(), + host: host.to_string(), + port, }; self.execute_action(&action).await } diff --git a/ballista/rust/core/src/execution_plans/distributed_query.rs b/ballista/rust/core/src/execution_plans/distributed_query.rs index e9d852817..67393c20b 100644 --- a/ballista/rust/core/src/execution_plans/distributed_query.rs +++ b/ballista/rust/core/src/execution_plans/distributed_query.rs @@ -317,16 +317,19 @@ async fn fetch_partition( let partition_id = location.partition_id.ok_or_else(|| { DataFusionError::Internal("Received empty partition id".to_owned()) })?; - let mut ballista_client = - BallistaClient::try_new(metadata.host.as_str(), metadata.port as u16) - .await - .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; + let host = metadata.host.as_str(); + let port = metadata.port as u16; + let mut ballista_client = BallistaClient::try_new(host, port) + .await + .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; ballista_client .fetch_partition( &partition_id.job_id, partition_id.stage_id as usize, partition_id.partition_id as usize, &location.path, + host, + port, ) .await .map_err(|e| DataFusionError::Execution(format!("{:?}", e))) diff --git a/ballista/rust/core/src/execution_plans/shuffle_reader.rs b/ballista/rust/core/src/execution_plans/shuffle_reader.rs index 17609c77d..0c153d3e1 100644 --- a/ballista/rust/core/src/execution_plans/shuffle_reader.rs +++ b/ballista/rust/core/src/execution_plans/shuffle_reader.rs @@ -185,16 +185,19 @@ async fn fetch_partition( let partition_id = &location.partition_id; // TODO for shuffle client connections, we should avoid creating new connections again and again. // And we should also avoid to keep alive too many connections for long time. - let mut ballista_client = - BallistaClient::try_new(metadata.host.as_str(), metadata.port as u16) - .await - .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; + let host = metadata.host.as_str(); + let port = metadata.port as u16; + let mut ballista_client = BallistaClient::try_new(host, port) + .await + .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; ballista_client .fetch_partition( &partition_id.job_id, partition_id.stage_id as usize, partition_id.partition_id as usize, &location.path, + host, + port, ) .await .map_err(|e| DataFusionError::Execution(format!("{:?}", e))) diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index 1e3be74b3..4553c2f48 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -19,6 +19,7 @@ //! as convenience code for interacting with the generated code. use crate::{error::BallistaError, serde::scheduler::Action as BallistaAction}; +use arrow_flight::sql::ProstMessageExt; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_plan::{FunctionRegistry, Operator}; use datafusion::physical_plan::join_utils::JoinSide; @@ -39,6 +40,19 @@ pub mod generated; pub mod physical_plan; pub mod scheduler; +impl ProstMessageExt for protobuf::Action { + fn type_url() -> &'static str { + "type.googleapis.com/arrow.flight.protocol.sql.Action" + } + + fn as_any(&self) -> prost_types::Any { + prost_types::Any { + type_url: protobuf::Action::type_url().to_string(), + value: self.encode_to_vec(), + } + } +} + pub fn decode_protobuf(bytes: &[u8]) -> Result { let mut buf = Cursor::new(bytes); diff --git a/ballista/rust/core/src/serde/scheduler/from_proto.rs b/ballista/rust/core/src/serde/scheduler/from_proto.rs index 536e2a5c9..cfe0cbbf2 100644 --- a/ballista/rust/core/src/serde/scheduler/from_proto.rs +++ b/ballista/rust/core/src/serde/scheduler/from_proto.rs @@ -44,6 +44,8 @@ impl TryInto for protobuf::Action { stage_id: fetch.stage_id as usize, partition_id: fetch.partition_id as usize, path: fetch.path, + host: fetch.host, + port: fetch.port as u16, }), _ => Err(BallistaError::General( "scheduler::from_proto(Action) invalid or missing action".to_owned(), diff --git a/ballista/rust/core/src/serde/scheduler/mod.rs b/ballista/rust/core/src/serde/scheduler/mod.rs index c1fa78a23..7a710f494 100644 --- a/ballista/rust/core/src/serde/scheduler/mod.rs +++ b/ballista/rust/core/src/serde/scheduler/mod.rs @@ -40,6 +40,8 @@ pub enum Action { stage_id: usize, partition_id: usize, path: String, + host: String, + port: u16, }, } diff --git a/ballista/rust/core/src/serde/scheduler/to_proto.rs b/ballista/rust/core/src/serde/scheduler/to_proto.rs index 5a7ad3948..0c43b5331 100644 --- a/ballista/rust/core/src/serde/scheduler/to_proto.rs +++ b/ballista/rust/core/src/serde/scheduler/to_proto.rs @@ -41,12 +41,16 @@ impl TryInto for Action { stage_id, partition_id, path, + host, + port, } => Ok(protobuf::Action { action_type: Some(ActionType::FetchPartition(protobuf::FetchPartition { job_id, stage_id: stage_id as u32, partition_id: partition_id as u32, path, + host, + port: port as u32, })), settings: vec![], }), diff --git a/ballista/rust/executor/src/flight_service.rs b/ballista/rust/executor/src/flight_service.rs index 2c25107d0..82c4f0ae9 100644 --- a/ballista/rust/executor/src/flight_service.rs +++ b/ballista/rust/executor/src/flight_service.rs @@ -17,6 +17,7 @@ //! Implementation of the Apache Arrow Flight protocol that wraps an executor. +use std::convert::TryFrom; use std::fs::File; use std::pin::Pin; @@ -35,7 +36,7 @@ use datafusion::arrow::{ record_batch::RecordBatch, }; use futures::{Stream, StreamExt}; -use log::{debug, warn}; +use log::{debug, info, warn}; use std::io::{Read, Seek}; use tokio::sync::mpsc::channel; use tokio::{ @@ -43,6 +44,7 @@ use tokio::{ task, }; use tokio_stream::wrappers::ReceiverStream; +use tonic::metadata::MetadataValue; use tonic::{Request, Response, Status, Streaming}; type FlightDataSender = Sender>; @@ -135,7 +137,23 @@ impl FlightService for BallistaFlightService { &self, _request: Request>, ) -> Result, Status> { - Err(Status::unimplemented("handshake")) + let token = uuid::Uuid::new_v4(); + info!("do_handshake token={}", token); + + let result = HandshakeResponse { + protocol_version: 0, + payload: token.as_bytes().to_vec(), + }; + let result = Ok(result); + let output = futures::stream::iter(vec![result]); + let str = format!("Bearer {}", token); + let mut resp: Response< + Pin> + Sync + Send>>, + > = Response::new(Box::pin(output)); + let md = MetadataValue::try_from(str) + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + resp.metadata_mut().insert("authorization", md); + Ok(resp) } async fn list_flights( diff --git a/ballista/rust/scheduler/src/flight_sql.rs b/ballista/rust/scheduler/src/flight_sql.rs index 2b1c3fcda..a6218c48e 100644 --- a/ballista/rust/scheduler/src/flight_sql.rs +++ b/ballista/rust/scheduler/src/flight_sql.rs @@ -24,25 +24,33 @@ use arrow_flight::sql::{ CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery, - CommandStatementUpdate, SqlInfo, TicketStatementQuery, + CommandStatementUpdate, ProstAnyExt, SqlInfo, TicketStatementQuery, }; use arrow_flight::{ - Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, Location, Ticket, + Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, + HandshakeResponse, Location, Ticket, }; use log::{debug, error, warn}; use std::collections::HashMap; +use std::convert::TryFrom; +use std::pin::Pin; +use std::str::FromStr; use std::sync::{Arc, Mutex}; use std::time::Duration; use tonic::{Request, Response, Status, Streaming}; use crate::scheduler_server::SchedulerServer; +use arrow_flight::flight_service_client::FlightServiceClient; +use arrow_flight::sql::ProstMessageExt; use arrow_flight::SchemaAsIpc; use ballista_core::config::BallistaConfig; use ballista_core::serde::protobuf; +use ballista_core::serde::protobuf::action::ActionType::FetchPartition; use ballista_core::serde::protobuf::job_status; use ballista_core::serde::protobuf::CompletedJob; use ballista_core::serde::protobuf::JobStatus; use ballista_core::serde::protobuf::PhysicalPlanNode; +use ballista_core::utils::create_grpc_client_connection; use datafusion::arrow; use datafusion::arrow::datatypes::Schema; use datafusion::arrow::ipc::writer::{IpcDataGenerator, IpcWriteOptions}; @@ -52,11 +60,14 @@ use datafusion::prelude::SessionContext; use datafusion_proto::protobuf::LogicalPlanNode; use prost::Message; use tokio::time::sleep; +use tonic::codegen::futures_core::Stream; +use tonic::metadata::MetadataValue; use uuid::Uuid; pub struct FlightSqlServiceImpl { server: SchedulerServer, statements: Arc>>, + contexts: Arc>>>, } impl FlightSqlServiceImpl { @@ -64,10 +75,11 @@ impl FlightSqlServiceImpl { Self { server, statements: Arc::new(Mutex::new(HashMap::new())), + contexts: Arc::new(Mutex::new(HashMap::new())), } } - async fn create_ctx(&self) -> Result, Status> { + async fn create_ctx(&self) -> Result { let config_builder = BallistaConfig::builder(); let config = config_builder .build() @@ -81,7 +93,45 @@ impl FlightSqlServiceImpl { .map_err(|e| { Status::internal(format!("Failed to create SessionContext: {:?}", e)) })?; - Ok(ctx) + let handle = Uuid::new_v4(); + let mut contexts = self + .contexts + .try_lock() + .map_err(|e| Status::internal(format!("Error locking contexts: {}", e)))?; + contexts.insert(handle.clone(), ctx); + Ok(handle) + } + + fn get_ctx(&self, req: &Request) -> Result, Status> { + let auth = req + .metadata() + .get("authorization") + .ok_or(Status::internal("No authorization header!"))?; + let str = auth + .to_str() + .map_err(|e| Status::internal(format!("Error parsing header: {}", e)))?; + let authorization = str.to_string(); + let bearer = "Bearer "; + if !authorization.starts_with(bearer) { + Err(Status::internal(format!("Invalid auth header!")))?; + } + let auth = authorization[bearer.len()..].to_string(); + + let handle = Uuid::from_str(auth.as_str()) + .map_err(|e| Status::internal(format!("Error locking contexts: {}", e)))?; + let contexts = self + .contexts + .try_lock() + .map_err(|e| Status::internal(format!("Error locking contexts: {}", e)))?; + let context = if let Some(context) = contexts.get(&handle) { + context + } else { + Err(Status::internal(format!( + "Context handle not found: {}", + handle + )))? + }; + Ok(context.clone()) } async fn prepare_statement( @@ -146,12 +196,21 @@ impl FlightSqlServiceImpl { ) -> Result, Status> { let mut fieps: Vec<_> = vec![]; for loc in completed.partition_location.iter() { + let (host, port) = if let Some(ref md) = loc.executor_meta { + (md.host.clone(), md.port) + } else { + Err(Status::internal( + "Invalid partition location, missing executor metadata".to_string(), + ))? + }; let fetch = if let Some(ref id) = loc.partition_id { let fetch = protobuf::FetchPartition { job_id: id.job_id.clone(), stage_id: id.stage_id, partition_id: id.partition_id, path: loc.path.clone(), + host: host.clone(), + port, }; protobuf::Action { action_type: Some(protobuf::action::ActionType::FetchPartition( @@ -162,23 +221,17 @@ impl FlightSqlServiceImpl { } else { Err(Status::internal("Error getting partition ID".to_string()))? }; - let authority = if let Some(ref md) = loc.executor_meta { - format!("{}:{}", md.host, md.port) - } else { - Err(Status::internal( - "Invalid partition location, missing executor metadata".to_string(), - ))? - }; if let Some(ref stats) = loc.partition_stats { *num_rows += stats.num_rows; *num_bytes += stats.num_bytes; } else { Err(Status::internal("Error getting stats".to_string()))? } + let authority = format!("{}:{}", &host, &port); // TODO: my host & port let loc = Location { uri: format!("grpc+tcp://{}", authority), }; - let buf = fetch.encode_to_vec(); + let buf = fetch.as_any().encode_to_vec(); let ticket = Ticket { ticket: buf }; let fiep = FlightEndpoint { ticket: Some(ticket), @@ -309,33 +362,140 @@ impl FlightSqlServiceImpl { impl FlightSqlService for FlightSqlServiceImpl { type FlightService = FlightSqlServiceImpl; + async fn do_handshake( + &self, + request: Request>, + ) -> Result< + Response> + Send>>>, + Status, + > { + debug!("do_handshake"); + for md in request.metadata().iter() { + debug!("{:?}", md); + } + + let basic = "Basic "; + let authorization = request + .metadata() + .get("authorization") + .ok_or(Status::invalid_argument("authorization field not present"))? + .to_str() + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + if !authorization.starts_with(basic) { + Err(Status::invalid_argument(format!( + "Auth type not implemented: {}", + authorization + )))?; + } + let base64 = &authorization[basic.len()..]; + let bytes = base64::decode(base64) + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + let str = String::from_utf8(bytes) + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + let parts: Vec<_> = str.split(":").collect(); + if parts.len() != 2 { + Err(Status::invalid_argument(format!( + "Invalid authorization header" + )))?; + } + let user = parts[0]; + let pass = parts[1]; + if user != "admin" || pass != "password" { + Err(Status::unauthenticated("Invalid credentials!"))? + } + + let token = self.create_ctx().await?; + + let result = HandshakeResponse { + protocol_version: 0, + payload: token.as_bytes().to_vec(), + }; + let result = Ok(result); + let output = futures::stream::iter(vec![result]); + let str = format!("Bearer {}", token.to_string()); + let mut resp: Response> + Send>>> = + Response::new(Box::pin(output)); + let md = MetadataValue::try_from(str) + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + resp.metadata_mut().insert("authorization", md); + Ok(resp) + } + + async fn do_get_fallback( + &self, + _request: Request, + message: prost_types::Any, + ) -> Result::DoGetStream>, Status> { + println!("type_url: {}", message.type_url); + if message.is::() { + println!("got action!"); + let action: protobuf::Action = message + .unpack() + .map_err(|e| Status::internal(format!("{:?}", e)))? + .ok_or(Status::internal("Expected an Action but got None!"))?; + println!("action={:?}", action); + let (host, port) = match &action.action_type { + Some(FetchPartition(fp)) => (fp.host.clone(), fp.port), + None => Err(Status::internal("Expected an ActionType but got None!"))?, + }; + + let addr = format!("http://{}:{}", host, port); + println!("BallistaClient connecting to {}", addr); + let connection = + create_grpc_client_connection(addr.clone()) + .await + .map_err(|e| { + Status::internal(format!( + "Error connecting to Ballista scheduler or executor at {}: {:?}", + addr, e + )) + })?; + let mut flight_client = FlightServiceClient::new(connection); + let buf = action.encode_to_vec(); + let request = Request::new(Ticket { ticket: buf }); + + let stream = flight_client + .do_get(request) + .await + .map_err(|e| Status::internal(format!("{:?}", e)))? + .into_inner(); + return Ok(Response::new(Box::pin(stream))); + } + + Err(Status::unimplemented(format!( + "do_get: The defined request is invalid: {}", + message.type_url + ))) + } + async fn get_flight_info_statement( &self, query: CommandStatementQuery, - _request: Request, + request: Request, ) -> Result, Status> { - debug!("Got query:\n{}", query.query); + debug!("get_flight_info_statement query:\n{}", query.query); - let ctx = self.create_ctx().await?; + let ctx = self.get_ctx(&request)?; let plan = Self::prepare_statement(&query.query, &ctx).await?; let resp = self.execute_plan(ctx, &plan).await?; - debug!("Responding to query..."); + debug!("Returning flight info..."); Ok(resp) } async fn get_flight_info_prepared_statement( &self, handle: CommandPreparedStatementQuery, - _request: Request, + request: Request, ) -> Result, Status> { - let ctx = self.create_ctx().await?; + debug!("get_flight_info_prepared_statement"); + let ctx = self.get_ctx(&request)?; let handle = Uuid::from_slice(handle.prepared_statement_handle.as_slice()) .map_err(|e| Status::internal(format!("Error decoding handle: {}", e)))?; let plan = self.get_plan(&handle)?; let resp = self.execute_plan(ctx, &plan).await?; - debug!("Responding to query..."); + debug!("Responding to query {}...", handle); Ok(resp) } @@ -344,6 +504,7 @@ impl FlightSqlService for FlightSqlServiceImpl { _query: CommandGetCatalogs, _request: Request, ) -> Result, Status> { + debug!("get_flight_info_catalogs"); Err(Status::unimplemented("Implement get_flight_info_catalogs")) } async fn get_flight_info_schemas( @@ -351,6 +512,7 @@ impl FlightSqlService for FlightSqlServiceImpl { _query: CommandGetDbSchemas, _request: Request, ) -> Result, Status> { + debug!("get_flight_info_schemas"); Err(Status::unimplemented("Implement get_flight_info_schemas")) } async fn get_flight_info_tables( @@ -358,6 +520,7 @@ impl FlightSqlService for FlightSqlServiceImpl { _query: CommandGetTables, _request: Request, ) -> Result, Status> { + debug!("get_flight_info_tables"); Err(Status::unimplemented("Implement get_flight_info_tables")) } async fn get_flight_info_table_types( @@ -365,6 +528,7 @@ impl FlightSqlService for FlightSqlServiceImpl { _query: CommandGetTableTypes, _request: Request, ) -> Result, Status> { + debug!("get_flight_info_table_types"); Err(Status::unimplemented( "Implement get_flight_info_table_types", )) @@ -374,6 +538,7 @@ impl FlightSqlService for FlightSqlServiceImpl { _query: CommandGetSqlInfo, _request: Request, ) -> Result, Status> { + debug!("get_flight_info_sql_info"); // TODO: implement for FlightSQL JDBC to work Err(Status::unimplemented("Implement CommandGetSqlInfo")) } @@ -382,6 +547,7 @@ impl FlightSqlService for FlightSqlServiceImpl { _query: CommandGetPrimaryKeys, _request: Request, ) -> Result, Status> { + debug!("get_flight_info_primary_keys"); Err(Status::unimplemented( "Implement get_flight_info_primary_keys", )) @@ -391,6 +557,7 @@ impl FlightSqlService for FlightSqlServiceImpl { _query: CommandGetExportedKeys, _request: Request, ) -> Result, Status> { + debug!("get_flight_info_exported_keys"); Err(Status::unimplemented( "Implement get_flight_info_exported_keys", )) @@ -400,6 +567,7 @@ impl FlightSqlService for FlightSqlServiceImpl { _query: CommandGetImportedKeys, _request: Request, ) -> Result, Status> { + debug!("get_flight_info_imported_keys"); Err(Status::unimplemented( "Implement get_flight_info_imported_keys", )) @@ -409,6 +577,7 @@ impl FlightSqlService for FlightSqlServiceImpl { _query: CommandGetCrossReference, _request: Request, ) -> Result, Status> { + debug!("get_flight_info_cross_reference"); Err(Status::unimplemented( "Implement get_flight_info_cross_reference", )) @@ -419,6 +588,7 @@ impl FlightSqlService for FlightSqlServiceImpl { _ticket: TicketStatementQuery, _request: Request, ) -> Result::DoGetStream>, Status> { + debug!("do_get_statement"); // let handle = Uuid::from_slice(&ticket.statement_handle) // .map_err(|e| Status::internal(format!("Error decoding ticket: {}", e)))?; // let statements = self.statements.try_lock() @@ -432,6 +602,7 @@ impl FlightSqlService for FlightSqlServiceImpl { _query: CommandPreparedStatementQuery, _request: Request, ) -> Result::DoGetStream>, Status> { + debug!("do_get_prepared_statement"); Err(Status::unimplemented("Implement do_get_prepared_statement")) } async fn do_get_catalogs( @@ -439,6 +610,7 @@ impl FlightSqlService for FlightSqlServiceImpl { _query: CommandGetCatalogs, _request: Request, ) -> Result::DoGetStream>, Status> { + debug!("do_get_catalogs"); Err(Status::unimplemented("Implement do_get_catalogs")) } async fn do_get_schemas( @@ -446,6 +618,7 @@ impl FlightSqlService for FlightSqlServiceImpl { _query: CommandGetDbSchemas, _request: Request, ) -> Result::DoGetStream>, Status> { + debug!("do_get_schemas"); Err(Status::unimplemented("Implement do_get_schemas")) } async fn do_get_tables( @@ -453,6 +626,7 @@ impl FlightSqlService for FlightSqlServiceImpl { _query: CommandGetTables, _request: Request, ) -> Result::DoGetStream>, Status> { + debug!("do_get_tables"); Err(Status::unimplemented("Implement do_get_tables")) } async fn do_get_table_types( @@ -460,6 +634,7 @@ impl FlightSqlService for FlightSqlServiceImpl { _query: CommandGetTableTypes, _request: Request, ) -> Result::DoGetStream>, Status> { + debug!("do_get_table_types"); Err(Status::unimplemented("Implement do_get_table_types")) } async fn do_get_sql_info( @@ -467,6 +642,7 @@ impl FlightSqlService for FlightSqlServiceImpl { _query: CommandGetSqlInfo, _request: Request, ) -> Result::DoGetStream>, Status> { + debug!("do_get_sql_info"); Err(Status::unimplemented("Implement do_get_sql_info")) } async fn do_get_primary_keys( @@ -474,6 +650,7 @@ impl FlightSqlService for FlightSqlServiceImpl { _query: CommandGetPrimaryKeys, _request: Request, ) -> Result::DoGetStream>, Status> { + debug!("do_get_primary_keys"); Err(Status::unimplemented("Implement do_get_primary_keys")) } async fn do_get_exported_keys( @@ -481,6 +658,7 @@ impl FlightSqlService for FlightSqlServiceImpl { _query: CommandGetExportedKeys, _request: Request, ) -> Result::DoGetStream>, Status> { + debug!("do_get_exported_keys"); Err(Status::unimplemented("Implement do_get_exported_keys")) } async fn do_get_imported_keys( @@ -488,6 +666,7 @@ impl FlightSqlService for FlightSqlServiceImpl { _query: CommandGetImportedKeys, _request: Request, ) -> Result::DoGetStream>, Status> { + debug!("do_get_imported_keys"); Err(Status::unimplemented("Implement do_get_imported_keys")) } async fn do_get_cross_reference( @@ -495,6 +674,7 @@ impl FlightSqlService for FlightSqlServiceImpl { _query: CommandGetCrossReference, _request: Request, ) -> Result::DoGetStream>, Status> { + debug!("do_get_cross_reference"); Err(Status::unimplemented("Implement do_get_cross_reference")) } // do_put @@ -503,6 +683,7 @@ impl FlightSqlService for FlightSqlServiceImpl { _ticket: CommandStatementUpdate, _request: Request>, ) -> Result { + debug!("do_put_statement_update"); Err(Status::unimplemented("Implement do_put_statement_update")) } async fn do_put_prepared_statement_query( @@ -510,29 +691,37 @@ impl FlightSqlService for FlightSqlServiceImpl { _query: CommandPreparedStatementQuery, _request: Request>, ) -> Result::DoPutStream>, Status> { + debug!("do_put_prepared_statement_query"); Err(Status::unimplemented( "Implement do_put_prepared_statement_query", )) } async fn do_put_prepared_statement_update( &self, - _handle: CommandPreparedStatementUpdate, - _request: Request>, + handle: CommandPreparedStatementUpdate, + request: Request>, ) -> Result { - Err(Status::unimplemented( - "Implement do_put_prepared_statement_update", - )) + debug!("do_put_prepared_statement_update"); + let ctx = self.get_ctx(&request)?; + let handle = Uuid::from_slice(handle.prepared_statement_handle.as_slice()) + .map_err(|e| Status::internal(format!("Error decoding handle: {}", e)))?; + let plan = self.get_plan(&handle)?; + let _ = self.execute_plan(ctx, &plan).await?; + debug!("Sending -1 rows affected"); + Ok(-1) } async fn do_action_create_prepared_statement( &self, query: ActionCreatePreparedStatementRequest, - _request: Request, + request: Request, ) -> Result { - let ctx = self.create_ctx().await?; + debug!("do_action_create_prepared_statement"); + let ctx = self.get_ctx(&request)?; let plan = Self::prepare_statement(&query.query, &ctx).await?; let schema_bytes = self.df_schema_to_arrow(plan.schema())?; let handle = self.cache_plan(plan)?; + debug!("Prepared statement {}:\n{}", handle, query.query); let res = ActionCreatePreparedStatementResult { prepared_statement_handle: handle.as_bytes().to_vec(), dataset_schema: schema_bytes, @@ -546,8 +735,10 @@ impl FlightSqlService for FlightSqlServiceImpl { handle: ActionClosePreparedStatementRequest, _request: Request, ) { + debug!("do_action_close_prepared_statement"); let handle = Uuid::from_slice(handle.prepared_statement_handle.as_slice()); let handle = if let Ok(handle) = handle { + debug!("Closing {}", handle); handle } else { return; From 0c467194f81fe0059bb6422ba5968800973a3532 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Fri, 23 Sep 2022 14:37:45 -0600 Subject: [PATCH 2/2] Store sessions so users can register tables and query them through flight --- ballista/rust/core/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml index 2f0d167c2..2ef4995ac 100644 --- a/ballista/rust/core/Cargo.toml +++ b/ballista/rust/core/Cargo.toml @@ -38,7 +38,7 @@ simd = ["datafusion/simd"] [dependencies] ahash = { version = "0.8", default-features = false } -arrow-flight = { version = "22.0.0", features = ["flight-sql-experimental"] } +arrow-flight = { version = "22.0.0", features = ["flight-sql-experimental"] } async-trait = "0.1.41" chrono = { version = "0.4", default-features = false } clap = { version = "3", features = ["derive", "cargo"] }