diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs new file mode 100644 index 000000000000..0e75ac7c0c7f --- /dev/null +++ b/arrow-flight/src/client.rs @@ -0,0 +1,567 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{ + flight_service_client::FlightServiceClient, utils::flight_data_to_arrow_batch, + FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, Ticket, +}; +use arrow_array::{ArrayRef, RecordBatch}; +use arrow_schema::Schema; +use bytes::Bytes; +use futures::{future::ready, ready, stream, StreamExt}; +use std::{collections::HashMap, convert::TryFrom, pin::Pin, sync::Arc, task::Poll}; +use tonic::{metadata::MetadataMap, transport::Channel, Streaming}; + +use crate::error::{FlightError, Result}; + +/// A "Mid level" [Apache Arrow Flight](https://arrow.apache.org/docs/format/Flight.html) client. +/// +/// [`FlightClient`] is intended as a convenience for interactions +/// with Arrow Flight servers. For more direct control, such as access +/// to the response headers, use [`FlightServiceClient`] directly +/// via methods such as [`Self::inner`] or [`Self::into_inner`]. +/// +/// # Example: +/// ```no_run +/// # async fn run() { +/// # use arrow_flight::FlightClient; +/// # use bytes::Bytes; +/// use tonic::transport::Channel; +/// let channel = Channel::from_static("http://localhost:1234") +/// .connect() +/// .await +/// .expect("error connecting"); +/// +/// let mut client = FlightClient::new(channel); +/// +/// // Send 'Hi' bytes as the handshake request to the server +/// let response = client +/// .handshake(Bytes::from("Hi")) +/// .await +/// .expect("error handshaking"); +/// +/// // Expect the server responded with 'Ho' +/// assert_eq!(response, Bytes::from("Ho")); +/// # } +/// ``` +#[derive(Debug)] +pub struct FlightClient { + /// Optional grpc header metadata to include with each request + metadata: MetadataMap, + + /// The inner client + inner: FlightServiceClient, +} + +impl FlightClient { + /// Creates a client client with the provided [`Channel`](tonic::transport::Channel) + pub fn new(channel: Channel) -> Self { + Self::new_from_inner(FlightServiceClient::new(channel)) + } + + /// Creates a new higher level client with the provided lower level client + pub fn new_from_inner(inner: FlightServiceClient) -> Self { + Self { + metadata: MetadataMap::new(), + inner, + } + } + + /// Return a reference to gRPC metadata included with each request + pub fn metadata(&self) -> &MetadataMap { + &self.metadata + } + + /// Return a reference to gRPC metadata included with each request + /// + /// These headers can be used, for example, to include + /// authorization or other application specific headers. + pub fn metadata_mut(&mut self) -> &mut MetadataMap { + &mut self.metadata + } + + /// Add the specified header with value to all subsequent + /// requests. See [`Self::metadata_mut`] for fine grained control. + pub fn add_header(&mut self, key: &str, value: &str) -> Result<()> { + let key = tonic::metadata::MetadataKey::<_>::from_bytes(key.as_bytes()) + .map_err(|e| FlightError::ExternalError(Box::new(e)))?; + + let value = value + .parse() + .map_err(|e| FlightError::ExternalError(Box::new(e)))?; + + // ignore previous value + self.metadata.insert(key, value); + + Ok(()) + } + + /// Return a reference to the underlying tonic + /// [`FlightServiceClient`] + pub fn inner(&self) -> &FlightServiceClient { + &self.inner + } + + /// Return a mutable reference to the underlying tonic + /// [`FlightServiceClient`] + pub fn inner_mut(&mut self) -> &mut FlightServiceClient { + &mut self.inner + } + + /// Consume this client and return the underlying tonic + /// [`FlightServiceClient`] + pub fn into_inner(self) -> FlightServiceClient { + self.inner + } + + /// Perform an Arrow Flight handshake with the server, sending + /// `payload` as the [`HandshakeRequest`] payload and returning + /// the [`HandshakeResponse`](crate::HandshakeResponse) + /// bytes returned from the server + /// + /// See [`FlightClient`] docs for an example. + pub async fn handshake(&mut self, payload: impl Into) -> Result { + let request = HandshakeRequest { + protocol_version: 0, + payload: payload.into(), + }; + + // apply headers, etc + let request = self.make_request(stream::once(ready(request))); + + let mut response_stream = self.inner.handshake(request).await?.into_inner(); + + if let Some(response) = response_stream.next().await.transpose()? { + // check if there is another response + if response_stream.next().await.is_some() { + return Err(FlightError::protocol( + "Got unexpected second response from handshake", + )); + } + + Ok(response.payload) + } else { + Err(FlightError::protocol("No response from handshake")) + } + } + + /// Make a `DoGet` call to the server with the provided ticket, + /// returning a [`FlightRecordBatchStream`] for reading + /// [`RecordBatch`]es. + /// + /// # Example: + /// ```no_run + /// # async fn run() { + /// # use bytes::Bytes; + /// # use arrow_flight::FlightClient; + /// # use arrow_flight::Ticket; + /// # use arrow_array::RecordBatch; + /// # use tonic::transport::Channel; + /// # use futures::stream::TryStreamExt; + /// # let channel = Channel::from_static("http://localhost:1234") + /// # .connect() + /// # .await + /// # .expect("error connecting"); + /// # let ticket = Ticket { ticket: Bytes::from("foo") }; + /// let mut client = FlightClient::new(channel); + /// + /// // Invoke a do_get request on the server with a previously + /// // received Ticket + /// + /// let response = client + /// .do_get(ticket) + /// .await + /// .expect("error invoking do_get"); + /// + /// // Use try_collect to get the RecordBatches from the server + /// let batches: Vec = response + /// .try_collect() + /// .await + /// .expect("no stream errors"); + /// # } + /// ``` + pub async fn do_get(&mut self, ticket: Ticket) -> Result { + let request = self.make_request(ticket); + + let response = self.inner.do_get(request).await?.into_inner(); + + let flight_data_stream = FlightDataStream::new(response); + Ok(FlightRecordBatchStream::new(flight_data_stream)) + } + + /// Make a `GetFlightInfo` call to the server with the provided + /// [`FlightDescriptor`] and return the [`FlightInfo`] from the + /// server. The [`FlightInfo`] can be used with [`Self::do_get`] + /// to retrieve the requested batches. + /// + /// # Example: + /// ```no_run + /// # async fn run() { + /// # use arrow_flight::FlightClient; + /// # use arrow_flight::FlightDescriptor; + /// # use tonic::transport::Channel; + /// # let channel = Channel::from_static("http://localhost:1234") + /// # .connect() + /// # .await + /// # .expect("error connecting"); + /// let mut client = FlightClient::new(channel); + /// + /// // Send a 'CMD' request to the server + /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec()); + /// let flight_info = client + /// .get_flight_info(request) + /// .await + /// .expect("error handshaking"); + /// + /// // retrieve the first endpoint from the returned flight info + /// let ticket = flight_info + /// .endpoint[0] + /// // Extract the ticket + /// .ticket + /// .clone() + /// .expect("expected ticket"); + /// + /// // Retrieve the corresponding RecordBatch stream with do_get + /// let data = client + /// .do_get(ticket) + /// .await + /// .expect("error fetching data"); + /// # } + /// ``` + pub async fn get_flight_info( + &mut self, + descriptor: FlightDescriptor, + ) -> Result { + let request = self.make_request(descriptor); + + let response = self.inner.get_flight_info(request).await?.into_inner(); + Ok(response) + } + + // TODO other methods + // list_flights + // get_schema + // do_put + // do_action + // list_actions + // do_exchange + + /// return a Request, adding any configured metadata + fn make_request(&self, t: T) -> tonic::Request { + // Pass along metadata + let mut request = tonic::Request::new(t); + *request.metadata_mut() = self.metadata.clone(); + request + } +} + +/// A stream of [`RecordBatch`]es from from an Arrow Flight server. +/// +/// To access the lower level Flight messages directly, consider +/// calling [`Self::into_inner`] and using the [`FlightDataStream`] +/// directly. +#[derive(Debug)] +pub struct FlightRecordBatchStream { + inner: FlightDataStream, + got_schema: bool, +} + +impl FlightRecordBatchStream { + pub fn new(inner: FlightDataStream) -> Self { + Self { + inner, + got_schema: false, + } + } + + /// Has a message defining the schema been received yet? + pub fn got_schema(&self) -> bool { + self.got_schema + } + + /// Consume self and return the wrapped [`FlightDataStream`] + pub fn into_inner(self) -> FlightDataStream { + self.inner + } +} +impl futures::Stream for FlightRecordBatchStream { + type Item = Result; + + /// Returns the next [`RecordBatch`] available in this stream, or `None` if + /// there are no further results available. + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll>> { + loop { + let res = ready!(self.inner.poll_next_unpin(cx)); + match res { + // Inner exhausted + None => { + return Poll::Ready(None); + } + Some(Err(e)) => { + return Poll::Ready(Some(Err(e))); + } + // translate data + Some(Ok(data)) => match data.payload { + DecodedPayload::Schema(_) if self.got_schema => { + return Poll::Ready(Some(Err(FlightError::protocol( + "Unexpectedly saw multiple Schema messages in FlightData stream", + )))); + } + DecodedPayload::Schema(_) => { + self.got_schema = true; + // Need next message, poll inner again + } + DecodedPayload::RecordBatch(batch) => { + return Poll::Ready(Some(Ok(batch))); + } + DecodedPayload::None => { + // Need next message + } + }, + } + } + } +} + +/// Wrapper around a stream of [`FlightData`] that handles the details +/// of decoding low level Flight messages into [`Schema`] and +/// [`RecordBatch`]es, including details such as dictionaries. +/// +/// # Protocol Details +/// +/// The client handles flight messages as followes: +/// +/// - **None:** This message has no effect. This is useful to +/// transmit metadata without any actual payload. +/// +/// - **Schema:** The schema is (re-)set. Dictionaries are cleared and +/// the decoded schema is returned. +/// +/// - **Dictionary Batch:** A new dictionary for a given column is registered. An existing +/// dictionary for the same column will be overwritten. This +/// message is NOT visible. +/// +/// - **Record Batch:** Record batch is created based on the current +/// schema and dictionaries. This fails if no schema was transmitted +/// yet. +/// +/// All other message types (at the time of writing: e.g. tensor and +/// sparse tensor) lead to an error. +/// +/// Example usecases +/// +/// 1. Using this low level stream it is possible to receive a steam +/// of RecordBatches in FlightData that have different schemas by +/// handling multiple schema messages separately. +#[derive(Debug)] +pub struct FlightDataStream { + /// Underlying data stream + response: Streaming, + /// Decoding state + state: Option, + /// seen the end of the inner stream? + done: bool, +} + +impl FlightDataStream { + /// Create a new wrapper around the stream of FlightData + pub fn new(response: Streaming) -> Self { + Self { + state: None, + response, + done: false, + } + } + + /// Extracts flight data from the next message, updating decoding + /// state as necessary. + fn extract_message(&mut self, data: FlightData) -> Result> { + use arrow_ipc::MessageHeader; + let message = arrow_ipc::root_as_message(&data.data_header[..]).map_err(|e| { + FlightError::DecodeError(format!("Error decoding root message: {e}")) + })?; + + match message.header_type() { + MessageHeader::NONE => Ok(Some(DecodedFlightData::new_none(data))), + MessageHeader::Schema => { + let schema = Schema::try_from(&data).map_err(|e| { + FlightError::DecodeError(format!("Error decoding schema: {e}")) + })?; + + let schema = Arc::new(schema); + let dictionaries_by_field = HashMap::new(); + + self.state = Some(FlightStreamState { + schema: Arc::clone(&schema), + dictionaries_by_field, + }); + Ok(Some(DecodedFlightData::new_schema(data, schema))) + } + MessageHeader::DictionaryBatch => { + let state = if let Some(state) = self.state.as_mut() { + state + } else { + return Err(FlightError::protocol( + "Received DictionaryBatch prior to Schema", + )); + }; + + let buffer: arrow_buffer::Buffer = data.data_body.into(); + let dictionary_batch = + message.header_as_dictionary_batch().ok_or_else(|| { + FlightError::protocol( + "Could not get dictionary batch from DictionaryBatch message", + ) + })?; + + arrow_ipc::reader::read_dictionary( + &buffer, + dictionary_batch, + &state.schema, + &mut state.dictionaries_by_field, + &message.version(), + ) + .map_err(|e| { + FlightError::DecodeError(format!( + "Error decoding ipc dictionary: {e}" + )) + })?; + + // Updated internal state, but no decoded message + Ok(None) + } + MessageHeader::RecordBatch => { + let state = if let Some(state) = self.state.as_ref() { + state + } else { + return Err(FlightError::protocol( + "Received RecordBatch prior to Schema", + )); + }; + + let batch = flight_data_to_arrow_batch( + &data, + Arc::clone(&state.schema), + &state.dictionaries_by_field, + ) + .map_err(|e| { + FlightError::DecodeError(format!( + "Error decoding ipc RecordBatch: {e}" + )) + })?; + + Ok(Some(DecodedFlightData::new_record_batch(data, batch))) + } + other => { + let name = other.variant_name().unwrap_or("UNKNOWN"); + Err(FlightError::protocol(format!("Unexpected message: {name}"))) + } + } + } +} + +impl futures::Stream for FlightDataStream { + type Item = Result; + /// Returns the result of decoding the next [`FlightData`] message + /// from the server, or `None` if there are no further results + /// available. + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + if self.done { + return Poll::Ready(None); + } + loop { + let res = ready!(self.response.poll_next_unpin(cx)); + + return Poll::Ready(match res { + None => { + self.done = true; + None // inner is exhausted + } + Some(data) => Some(match data { + Err(e) => Err(FlightError::Tonic(e)), + Ok(data) => match self.extract_message(data) { + Ok(Some(extracted)) => Ok(extracted), + Ok(None) => continue, // Need next input message + Err(e) => Err(e), + }, + }), + }); + } + } +} + +/// tracks the state needed to reconstruct [`RecordBatch`]es from a +/// streaming flight response. +#[derive(Debug)] +struct FlightStreamState { + schema: Arc, + dictionaries_by_field: HashMap, +} + +/// FlightData and the decoded payload (Schema, RecordBatch), if any +#[derive(Debug)] +pub struct DecodedFlightData { + pub inner: FlightData, + pub payload: DecodedPayload, +} + +impl DecodedFlightData { + pub fn new_none(inner: FlightData) -> Self { + Self { + inner, + payload: DecodedPayload::None, + } + } + + pub fn new_schema(inner: FlightData, schema: Arc) -> Self { + Self { + inner, + payload: DecodedPayload::Schema(schema), + } + } + + pub fn new_record_batch(inner: FlightData, batch: RecordBatch) -> Self { + Self { + inner, + payload: DecodedPayload::RecordBatch(batch), + } + } + + /// return the metadata field of the inner flight data + pub fn app_metadata(&self) -> &[u8] { + &self.inner.app_metadata + } +} + +/// The result of decoding [`FlightData`] +#[derive(Debug)] +pub enum DecodedPayload { + /// None (no data was sent in the corresponding FlightData) + None, + + /// A decoded Schema message + Schema(Arc), + + /// A decoded Record batch. + RecordBatch(RecordBatch), +} diff --git a/arrow-flight/src/error.rs b/arrow-flight/src/error.rs new file mode 100644 index 000000000000..fbb9efa44c24 --- /dev/null +++ b/arrow-flight/src/error.rs @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Errors for the Apache Arrow Flight crate +#[derive(Debug)] +pub enum FlightError { + /// Returned when functionality is not yet available. + NotYetImplemented(String), + /// Error from the underlying tonic library + Tonic(tonic::Status), + /// Some unexpected message was received + ProtocolError(String), + /// An error occured during decoding + DecodeError(String), + /// Some other (opaque) error + ExternalError(Box), +} + +impl FlightError { + pub fn protocol(message: impl Into) -> Self { + Self::ProtocolError(message.into()) + } + + /// Wraps an external error in an `ArrowError`. + pub fn from_external_error(error: Box) -> Self { + Self::ExternalError(error) + } +} + +impl std::fmt::Display for FlightError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // TODO better format / error + write!(f, "{:?}", self) + } +} + +impl std::error::Error for FlightError {} + +impl From for FlightError { + fn from(status: tonic::Status) -> Self { + Self::Tonic(status) + } +} + +pub type Result = std::result::Result; diff --git a/arrow-flight/src/lib.rs b/arrow-flight/src/lib.rs index 051509fb16e2..f30cb54844da 100644 --- a/arrow-flight/src/lib.rs +++ b/arrow-flight/src/lib.rs @@ -71,6 +71,13 @@ pub mod flight_service_server { pub use gen::flight_service_server::FlightServiceServer; } +/// Mid Level [`FlightClient`] for +pub mod client; +pub use client::FlightClient; + +/// Common error types +pub mod error; + pub use gen::Action; pub use gen::ActionType; pub use gen::BasicAuth; diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs new file mode 100644 index 000000000000..5bc1062f046d --- /dev/null +++ b/arrow-flight/tests/client.rs @@ -0,0 +1,309 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Integration test for "mid level" Client + +mod common { + pub mod server; +} +use arrow_flight::{ + error::FlightError, FlightClient, FlightDescriptor, FlightInfo, HandshakeRequest, + HandshakeResponse, +}; +use bytes::Bytes; +use common::server::TestFlightServer; +use futures::Future; +use tokio::{net::TcpListener, task::JoinHandle}; +use tonic::{ + transport::{Channel, Uri}, + Status, +}; + +use std::{net::SocketAddr, time::Duration}; + +const DEFAULT_TIMEOUT_SECONDS: u64 = 30; + +#[tokio::test] +async fn test_handshake() { + do_test(|test_server, mut client| async move { + let request_payload = Bytes::from("foo"); + let response_payload = Bytes::from("Bar"); + + let request = HandshakeRequest { + payload: request_payload.clone(), + protocol_version: 0, + }; + + let response = HandshakeResponse { + payload: response_payload.clone(), + protocol_version: 0, + }; + + test_server.set_handshake_response(Ok(response)); + let response = client.handshake(request_payload).await.unwrap(); + assert_eq!(response, response_payload); + assert_eq!(test_server.take_handshake_request(), Some(request)); + }) + .await; +} + +#[tokio::test] +async fn test_handshake_error() { + do_test(|test_server, mut client| async move { + let request_payload = "foo".to_string().into_bytes(); + let e = Status::unauthenticated("DENIED"); + test_server.set_handshake_response(Err(e)); + + let response = client.handshake(request_payload).await.unwrap_err(); + let e = Status::unauthenticated("DENIED"); + expect_status(response, e); + }) + .await; +} + +#[tokio::test] +async fn test_handshake_metadata() { + do_test(|test_server, mut client| async move { + client.add_header("foo", "bar").unwrap(); + + let request_payload = Bytes::from("Blarg"); + let response_payload = Bytes::from("Bazz"); + + let response = HandshakeResponse { + payload: response_payload.clone(), + protocol_version: 0, + }; + + test_server.set_handshake_response(Ok(response)); + client.handshake(request_payload).await.unwrap(); + ensure_metadata(&client, &test_server); + }) + .await; +} + +/// Verifies that all headers sent from the the client are in the request_metadata +fn ensure_metadata(client: &FlightClient, test_server: &TestFlightServer) { + let client_metadata = client.metadata().clone().into_headers(); + assert!(!client_metadata.is_empty()); + let metadata = test_server + .take_last_request_metadata() + .expect("No headers in server") + .into_headers(); + + for (k, v) in &client_metadata { + assert_eq!( + metadata.get(k).as_ref(), + Some(&v), + "Missing / Mismatched metadata {:?} sent {:?} got {:?}", + k, + client_metadata, + metadata + ); + } +} + +fn test_flight_info(request: &FlightDescriptor) -> FlightInfo { + FlightInfo { + schema: Bytes::new(), + endpoint: vec![], + flight_descriptor: Some(request.clone()), + total_bytes: 123, + total_records: 456, + } +} + +#[tokio::test] +async fn test_get_flight_info() { + do_test(|test_server, mut client| async move { + let request = FlightDescriptor::new_cmd(b"My Command".to_vec()); + + let expected_response = test_flight_info(&request); + test_server.set_get_flight_info_response(Ok(expected_response.clone())); + + let response = client.get_flight_info(request.clone()).await.unwrap(); + + assert_eq!(response, expected_response); + assert_eq!(test_server.take_get_flight_info_request(), Some(request)); + }) + .await; +} + +#[tokio::test] +async fn test_get_flight_info_error() { + do_test(|test_server, mut client| async move { + let request = FlightDescriptor::new_cmd(b"My Command".to_vec()); + + let e = Status::unauthenticated("DENIED"); + test_server.set_get_flight_info_response(Err(e)); + + let response = client.get_flight_info(request.clone()).await.unwrap_err(); + let e = Status::unauthenticated("DENIED"); + expect_status(response, e); + }) + .await; +} + +#[tokio::test] +async fn test_get_flight_info_metadata() { + do_test(|test_server, mut client| async move { + client.add_header("foo", "bar").unwrap(); + let request = FlightDescriptor::new_cmd(b"My Command".to_vec()); + + let expected_response = test_flight_info(&request); + test_server.set_get_flight_info_response(Ok(expected_response)); + client.get_flight_info(request.clone()).await.unwrap(); + ensure_metadata(&client, &test_server); + }) + .await; +} + +// TODO more negative tests (like if there are endpoints defined, etc) + +// TODO test for do_get + +/// Runs the future returned by the function, passing it a test server and client +async fn do_test(f: F) +where + F: Fn(TestFlightServer, FlightClient) -> Fut, + Fut: Future, +{ + let test_server = TestFlightServer::new(); + let fixture = TestFixture::new(&test_server).await; + let client = FlightClient::new(fixture.channel().await); + + // run the test function + f(test_server, client).await; + + // cleanly shutdown the test fixture + fixture.shutdown_and_wait().await +} + +fn expect_status(error: FlightError, expected: Status) { + let status = if let FlightError::Tonic(status) = error { + status + } else { + panic!("Expected FlightError::Tonic, got: {:?}", error); + }; + + assert_eq!( + status.code(), + expected.code(), + "Got {:?} want {:?}", + status, + expected + ); + assert_eq!( + status.message(), + expected.message(), + "Got {:?} want {:?}", + status, + expected + ); + assert_eq!( + status.details(), + expected.details(), + "Got {:?} want {:?}", + status, + expected + ); +} + +/// Creates and manages a running TestServer with a background task +struct TestFixture { + /// channel to send shutdown command + shutdown: Option>, + + /// Address the server is listening on + addr: SocketAddr, + + // handle for the server task + handle: Option>>, +} + +impl TestFixture { + /// create a new test fixture from the server + pub async fn new(test_server: &TestFlightServer) -> Self { + // let OS choose a a free port + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + println!("Listening on {addr}"); + + // prepare the shutdown channel + let (tx, rx) = tokio::sync::oneshot::channel(); + + let server_timeout = Duration::from_secs(DEFAULT_TIMEOUT_SECONDS); + + let shutdown_future = async move { + rx.await.ok(); + }; + + let serve_future = tonic::transport::Server::builder() + .timeout(server_timeout) + .add_service(test_server.service()) + .serve_with_incoming_shutdown( + tokio_stream::wrappers::TcpListenerStream::new(listener), + shutdown_future, + ); + + // Run the server in its own background task + let handle = tokio::task::spawn(serve_future); + + Self { + shutdown: Some(tx), + addr, + handle: Some(handle), + } + } + + /// Return a [`Channel`] connected to the TestServer + pub async fn channel(&self) -> Channel { + let url = format!("http://{}", self.addr); + let uri: Uri = url.parse().expect("Valid URI"); + Channel::builder(uri) + .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECONDS)) + .connect() + .await + .expect("error connecting to server") + } + + /// Stops the test server and waits for the server to shutdown + pub async fn shutdown_and_wait(mut self) { + if let Some(shutdown) = self.shutdown.take() { + shutdown.send(()).expect("server quit early"); + } + if let Some(handle) = self.handle.take() { + println!("Waiting on server to finish"); + handle + .await + .expect("task join error (panic?)") + .expect("Server Error found at shutdown"); + } + } +} + +impl Drop for TestFixture { + fn drop(&mut self) { + if let Some(shutdown) = self.shutdown.take() { + shutdown.send(()).ok(); + } + if self.handle.is_some() { + // tests should properly clean up TestFixture + println!("TestFixture::Drop called prior to `shutdown_and_wait`"); + } + } +} diff --git a/arrow-flight/tests/common/server.rs b/arrow-flight/tests/common/server.rs new file mode 100644 index 000000000000..f1cb140b68c7 --- /dev/null +++ b/arrow-flight/tests/common/server.rs @@ -0,0 +1,212 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::{Arc, Mutex}; + +use futures::stream::BoxStream; +use tonic::{metadata::MetadataMap, Request, Response, Status, Streaming}; + +use arrow_flight::{ + flight_service_server::{FlightService, FlightServiceServer}, + Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, + HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket, +}; + +#[derive(Debug, Clone)] +/// Flight server for testing, with configurable responses +pub struct TestFlightServer { + /// Shared state to configure responses + state: Arc>, +} + +impl TestFlightServer { + /// Create a `TestFlightServer` + pub fn new() -> Self { + Self { + state: Arc::new(Mutex::new(State::new())), + } + } + + /// Return an [`FlightServiceServer`] that can be used with a + /// [`Server`](tonic::transport::Server) + pub fn service(&self) -> FlightServiceServer { + // wrap up tonic goop + FlightServiceServer::new(self.clone()) + } + + /// Specify the response returned from the next call to handshake + pub fn set_handshake_response(&self, response: Result) { + let mut state = self.state.lock().expect("mutex not poisoned"); + + state.handshake_response.replace(response); + } + + /// Take and return last handshake request send to the server, + pub fn take_handshake_request(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .handshake_request + .take() + } + + /// Specify the response returned from the next call to handshake + pub fn set_get_flight_info_response(&self, response: Result) { + let mut state = self.state.lock().expect("mutex not poisoned"); + + state.get_flight_info_response.replace(response); + } + + /// Take and return last get_flight_info request send to the server, + pub fn take_get_flight_info_request(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .get_flight_info_request + .take() + } + + /// Returns the last metadata from a request received by the server + pub fn take_last_request_metadata(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .last_request_metadata + .take() + } + + /// Save the last request's metadatacom + fn save_metadata(&self, request: &Request) { + let metadata = request.metadata().clone(); + let mut state = self.state.lock().expect("mutex not poisoned"); + state.last_request_metadata = Some(metadata); + } +} + +/// mutable state for the TestFlightSwrver +#[derive(Debug, Default)] +struct State { + /// The last handshake request that was received + pub handshake_request: Option, + /// The next response to return from `handshake()` + pub handshake_response: Option>, + /// The last `get_flight_info` request received + pub get_flight_info_request: Option, + /// the next response to return from `get_flight_info` + pub get_flight_info_response: Option>, + /// The last request headers received + pub last_request_metadata: Option, +} + +impl State { + fn new() -> Self { + Default::default() + } +} + +/// Implement the FlightService trait +#[tonic::async_trait] +impl FlightService for TestFlightServer { + type HandshakeStream = BoxStream<'static, Result>; + type ListFlightsStream = BoxStream<'static, Result>; + type DoGetStream = BoxStream<'static, Result>; + type DoPutStream = BoxStream<'static, Result>; + type DoActionStream = BoxStream<'static, Result>; + type ListActionsStream = BoxStream<'static, Result>; + type DoExchangeStream = BoxStream<'static, Result>; + + async fn handshake( + &self, + request: Request>, + ) -> Result, Status> { + self.save_metadata(&request); + let handshake_request = request.into_inner().message().await?.unwrap(); + + let mut state = self.state.lock().expect("mutex not poisoned"); + state.handshake_request = Some(handshake_request); + + let response = state.handshake_response.take().unwrap_or_else(|| { + Err(Status::internal("No handshake response configured")) + })?; + + // turn into a streaming response + let output = futures::stream::iter(std::iter::once(Ok(response))); + Ok(Response::new(Box::pin(output) as Self::HandshakeStream)) + } + + async fn list_flights( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Implement list_flights")) + } + + async fn get_flight_info( + &self, + request: Request, + ) -> Result, Status> { + self.save_metadata(&request); + let mut state = self.state.lock().expect("mutex not poisoned"); + state.get_flight_info_request = Some(request.into_inner()); + let response = state.get_flight_info_response.take().unwrap_or_else(|| { + Err(Status::internal("No get_flight_info response configured")) + })?; + Ok(Response::new(response)) + } + + async fn get_schema( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Implement get_schema")) + } + + async fn do_get( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Implement do_get")) + } + + async fn do_put( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Implement do_put")) + } + + async fn do_action( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Implement do_action")) + } + + async fn list_actions( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Implement list_actions")) + } + + async fn do_exchange( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Implement do_exchange")) + } +}