diff --git a/arrow-flight/src/lib.rs b/arrow-flight/src/lib.rs index 87aeba1c1194..3057735a6ad7 100644 --- a/arrow-flight/src/lib.rs +++ b/arrow-flight/src/lib.rs @@ -454,6 +454,13 @@ impl Action { } } +impl Result { + /// Create a new Result with the specified body + pub fn new(body: impl Into) -> Self { + Self { body: body.into() } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs index 7537e46db403..032dad04923d 100644 --- a/arrow-flight/tests/client.rs +++ b/arrow-flight/tests/client.rs @@ -23,9 +23,10 @@ mod common { use arrow_array::{RecordBatch, UInt64Array}; use arrow_flight::{ decode::FlightRecordBatchStream, encode::FlightDataEncoderBuilder, - error::FlightError, FlightClient, FlightData, FlightDescriptor, FlightInfo, - HandshakeRequest, HandshakeResponse, PutResult, Ticket, + error::FlightError, Action, ActionType, Criteria, Empty, FlightClient, FlightData, + FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PutResult, Ticket, }; +use arrow_schema::{DataType, Field, Schema}; use bytes::Bytes; use common::server::TestFlightServer; use futures::{Future, StreamExt, TryStreamExt}; @@ -70,10 +71,9 @@ async fn test_handshake_error() { do_test(|test_server, mut client| async move { let request_payload = "foo-request-payload".to_string().into_bytes(); let e = Status::unauthenticated("DENIED"); - test_server.set_handshake_response(Err(e)); + test_server.set_handshake_response(Err(e.clone())); let response = client.handshake(request_payload).await.unwrap_err(); - let e = Status::unauthenticated("DENIED"); expect_status(response, e); }) .await; @@ -134,10 +134,9 @@ async fn test_get_flight_info_error() { let request = FlightDescriptor::new_cmd(b"My Command".to_vec()); let e = Status::unauthenticated("DENIED"); - test_server.set_get_flight_info_response(Err(e)); + test_server.set_get_flight_info_response(Err(e.clone())); let response = client.get_flight_info(request.clone()).await.unwrap_err(); - let e = Status::unauthenticated("DENIED"); expect_status(response, e); }) .await; @@ -213,7 +212,7 @@ async fn test_do_get_error_in_record_batch_stream() { let e = Status::data_loss("she's dead jim"); - let expected_response = vec![Ok(batch), Err(FlightError::Tonic(e.clone()))]; + let expected_response = vec![Ok(batch), Err(e.clone())]; test_server.set_do_get_response(expected_response); @@ -300,11 +299,13 @@ async fn test_do_put_error_stream() { let input_flight_data = test_flight_data().await; + let e = Status::invalid_argument("bad arg"); + let response = vec![ Ok(PutResult { app_metadata: Bytes::from("foo-metadata"), }), - Err(FlightError::Tonic(Status::invalid_argument("bad arg"))), + Err(e.clone()), ]; test_server.set_do_put_response(response); @@ -320,7 +321,6 @@ async fn test_do_put_error_stream() { Err(e) => e, }; - let e = Status::invalid_argument("bad arg"); expect_status(response, e); // server still got the request assert_eq!(test_server.take_do_put_request(), Some(input_flight_data)); @@ -404,6 +404,7 @@ async fn test_do_exchange_error_stream() { let input_flight_data = test_flight_data().await; + let e = Status::invalid_argument("the error"); let response = test_flight_data2() .await .into_iter() @@ -413,8 +414,7 @@ async fn test_do_exchange_error_stream() { Ok(m) } else { // make all messages after the first an error - let e = tonic::Status::invalid_argument("the error"); - Err(FlightError::Tonic(e)) + Err(e.clone()) } }) .collect(); @@ -432,7 +432,6 @@ async fn test_do_exchange_error_stream() { Err(e) => e, }; - let e = tonic::Status::invalid_argument("the error"); expect_status(response, e); // server still got the request assert_eq!( @@ -444,6 +443,309 @@ async fn test_do_exchange_error_stream() { .await; } +#[tokio::test] +async fn test_get_schema() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let schema = Schema::new(vec![Field::new("foo", DataType::Int64, true)]); + + let request = FlightDescriptor::new_cmd("my command"); + test_server.set_get_schema_response(Ok(schema.clone())); + + let response = client + .get_schema(request.clone()) + .await + .expect("error making request"); + + let expected_schema = schema; + let expected_request = request; + + assert_eq!(response, expected_schema); + assert_eq!( + test_server.take_get_schema_request(), + Some(expected_request) + ); + + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_get_schema_error() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + let request = FlightDescriptor::new_cmd("my command"); + + let e = Status::unauthenticated("DENIED"); + test_server.set_get_schema_response(Err(e.clone())); + + let response = client.get_schema(request).await.unwrap_err(); + expect_status(response, e); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_list_flights() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let infos = vec![ + test_flight_info(&FlightDescriptor::new_cmd("foo")), + test_flight_info(&FlightDescriptor::new_cmd("bar")), + ]; + + let response = infos.iter().map(|i| Ok(i.clone())).collect(); + test_server.set_list_flights_response(response); + + let response_stream = client + .list_flights("query") + .await + .expect("error making request"); + + let expected_response = infos; + let response: Vec<_> = response_stream + .try_collect() + .await + .expect("Error streaming data"); + + let expected_request = Some(Criteria { + expression: "query".into(), + }); + + assert_eq!(response, expected_response); + assert_eq!(test_server.take_list_flights_request(), expected_request); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_list_flights_error() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let response = client.list_flights("query").await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + let e = Status::internal("No list_flights response configured"); + expect_status(response, e); + // server still got the request + let expected_request = Some(Criteria { + expression: "query".into(), + }); + assert_eq!(test_server.take_list_flights_request(), expected_request); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_list_flights_error_in_stream() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let e = Status::data_loss("she's dead jim"); + + let response = vec![ + Ok(test_flight_info(&FlightDescriptor::new_cmd("foo"))), + Err(e.clone()), + ]; + test_server.set_list_flights_response(response); + + let response_stream = client + .list_flights("other query") + .await + .expect("error making request"); + + let response: Result, FlightError> = response_stream.try_collect().await; + + let response = response.unwrap_err(); + expect_status(response, e); + // server still got the request + let expected_request = Some(Criteria { + expression: "other query".into(), + }); + assert_eq!(test_server.take_list_flights_request(), expected_request); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_list_actions() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let actions = vec![ + ActionType { + r#type: "type 1".into(), + description: "awesomeness".into(), + }, + ActionType { + r#type: "type 2".into(), + description: "more awesomeness".into(), + }, + ]; + + let response = actions.iter().map(|i| Ok(i.clone())).collect(); + test_server.set_list_actions_response(response); + + let response_stream = client.list_actions().await.expect("error making request"); + + let expected_response = actions; + let response: Vec<_> = response_stream + .try_collect() + .await + .expect("Error streaming data"); + + assert_eq!(response, expected_response); + assert_eq!(test_server.take_list_actions_request(), Some(Empty {})); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_list_actions_error() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let response = client.list_actions().await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + let e = Status::internal("No list_actions response configured"); + expect_status(response, e); + // server still got the request + assert_eq!(test_server.take_list_actions_request(), Some(Empty {})); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_list_actions_error_in_stream() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let e = Status::data_loss("she's dead jim"); + + let response = vec![ + Ok(ActionType { + r#type: "type 1".into(), + description: "awesomeness".into(), + }), + Err(e.clone()), + ]; + test_server.set_list_actions_response(response); + + let response_stream = client.list_actions().await.expect("error making request"); + + let response: Result, FlightError> = response_stream.try_collect().await; + + let response = response.unwrap_err(); + expect_status(response, e); + // server still got the request + assert_eq!(test_server.take_list_actions_request(), Some(Empty {})); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_action() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let bytes = vec![Bytes::from("foo"), Bytes::from("blarg")]; + + let response = bytes + .iter() + .cloned() + .map(arrow_flight::Result::new) + .map(Ok) + .collect(); + test_server.set_do_action_response(response); + + let request = Action::new("action type", "action body"); + + let response_stream = client + .do_action(request.clone()) + .await + .expect("error making request"); + + let expected_response = bytes; + let response: Vec<_> = response_stream + .try_collect() + .await + .expect("Error streaming data"); + + assert_eq!(response, expected_response); + assert_eq!(test_server.take_do_action_request(), Some(request)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_action_error() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let request = Action::new("action type", "action body"); + + let response = client.do_action(request.clone()).await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + let e = Status::internal("No do_action response configured"); + expect_status(response, e); + // server still got the request + assert_eq!(test_server.take_do_action_request(), Some(request)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_action_error_in_stream() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let e = Status::data_loss("she's dead jim"); + + let request = Action::new("action type", "action body"); + + let response = vec![Ok(arrow_flight::Result::new("foo")), Err(e.clone())]; + test_server.set_do_action_response(response); + + let response_stream = client + .do_action(request.clone()) + .await + .expect("error making request"); + + let response: Result, FlightError> = response_stream.try_collect().await; + + let response = response.unwrap_err(); + expect_status(response, e); + // server still got the request + assert_eq!(test_server.take_do_action_request(), Some(request)); + ensure_metadata(&client, &test_server); + }) + .await; +} + async fn test_flight_data() -> Vec { let batch = RecordBatch::try_from_iter(vec![( "col", diff --git a/arrow-flight/tests/common/server.rs b/arrow-flight/tests/common/server.rs index 5060d9d0cc89..b87019d632c4 100644 --- a/arrow-flight/tests/common/server.rs +++ b/arrow-flight/tests/common/server.rs @@ -18,15 +18,15 @@ use std::sync::{Arc, Mutex}; use arrow_array::RecordBatch; +use arrow_schema::Schema; use futures::{stream::BoxStream, StreamExt, TryStreamExt}; use tonic::{metadata::MetadataMap, Request, Response, Status, Streaming}; use arrow_flight::{ encode::FlightDataEncoderBuilder, - error::FlightError, flight_service_server::{FlightService, FlightServiceServer}, Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, - HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket, + HandshakeRequest, HandshakeResponse, PutResult, SchemaAsIpc, SchemaResult, Ticket, }; #[derive(Debug, Clone)] @@ -84,7 +84,7 @@ impl TestFlightServer { } /// Specify the response returned from the next call to `do_get` - pub fn set_do_get_response(&self, response: Vec>) { + pub fn set_do_get_response(&self, response: Vec>) { let mut state = self.state.lock().expect("mutex not poisoned"); state.do_get_response.replace(response); } @@ -99,7 +99,7 @@ impl TestFlightServer { } /// Specify the response returned from the next call to `do_put` - pub fn set_do_put_response(&self, response: Vec>) { + pub fn set_do_put_response(&self, response: Vec>) { let mut state = self.state.lock().expect("mutex not poisoned"); state.do_put_response.replace(response); } @@ -114,10 +114,7 @@ impl TestFlightServer { } /// Specify the response returned from the next call to `do_exchange` - pub fn set_do_exchange_response( - &self, - response: Vec>, - ) { + pub fn set_do_exchange_response(&self, response: Vec>) { let mut state = self.state.lock().expect("mutex not poisoned"); state.do_exchange_response.replace(response); } @@ -131,6 +128,69 @@ impl TestFlightServer { .take() } + /// Specify the response returned from the next call to `list_flights` + pub fn set_list_flights_response(&self, response: Vec>) { + let mut state = self.state.lock().expect("mutex not poisoned"); + state.list_flights_response.replace(response); + } + + /// Take and return last list_flights request send to the server, + pub fn take_list_flights_request(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .list_flights_request + .take() + } + + /// Specify the response returned from the next call to `get_schema` + pub fn set_get_schema_response(&self, response: Result) { + let mut state = self.state.lock().expect("mutex not poisoned"); + state.get_schema_response.replace(response); + } + + /// Take and return last get_schema request send to the server, + pub fn take_get_schema_request(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .get_schema_request + .take() + } + + /// Specify the response returned from the next call to `list_actions` + pub fn set_list_actions_response(&self, response: Vec>) { + let mut state = self.state.lock().expect("mutex not poisoned"); + state.list_actions_response.replace(response); + } + + /// Take and return last list_actions request send to the server, + pub fn take_list_actions_request(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .list_actions_request + .take() + } + + /// Specify the response returned from the next call to `do_action` + pub fn set_do_action_response( + &self, + response: Vec>, + ) { + let mut state = self.state.lock().expect("mutex not poisoned"); + state.do_action_response.replace(response); + } + + /// Take and return last do_action request send to the server, + pub fn take_do_action_request(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .do_action_request + .take() + } + /// Returns the last metadata from a request received by the server pub fn take_last_request_metadata(&self) -> Option { self.state @@ -162,15 +222,31 @@ struct State { /// The last do_get request received pub do_get_request: Option, /// The next response returned from `do_get` - pub do_get_response: Option>>, + pub do_get_response: Option>>, /// The last do_put request received pub do_put_request: Option>, /// The next response returned from `do_put` - pub do_put_response: Option>>, + pub do_put_response: Option>>, /// The last do_exchange request received pub do_exchange_request: Option>, /// The next response returned from `do_exchange` - pub do_exchange_response: Option>>, + pub do_exchange_response: Option>>, + /// The last list_flights request received + pub list_flights_request: Option, + /// The next response returned from `list_flights` + pub list_flights_response: Option>>, + /// The last get_schema request received + pub get_schema_request: Option, + /// The next response returned from `get_schema` + pub get_schema_response: Option>, + /// The last list_actions request received + pub list_actions_request: Option, + /// The next response returned from `list_actions` + pub list_actions_response: Option>>, + /// The last do_action request received + pub do_action_request: Option, + /// The next response returned from `do_action` + pub do_action_response: Option>>, /// The last request headers received pub last_request_metadata: Option, } @@ -213,9 +289,21 @@ impl FlightService for TestFlightServer { async fn list_flights( &self, - _request: Request, + request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Implement list_flights")) + self.save_metadata(&request); + let mut state = self.state.lock().expect("mutex not poisoned"); + + state.list_flights_request = Some(request.into_inner()); + + let flights: Vec<_> = state + .list_flights_response + .take() + .ok_or_else(|| Status::internal("No list_flights response configured"))?; + + let flights_stream = futures::stream::iter(flights); + + Ok(Response::new(flights_stream.boxed())) } async fn get_flight_info( @@ -233,9 +321,22 @@ impl FlightService for TestFlightServer { async fn get_schema( &self, - _request: Request, + request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Implement get_schema")) + self.save_metadata(&request); + let mut state = self.state.lock().expect("mutex not poisoned"); + state.get_schema_request = Some(request.into_inner()); + let schema = state.get_schema_response.take().unwrap_or_else(|| { + Err(Status::internal("No get_schema response configured")) + })?; + + // encode the schema + let options = arrow_ipc::writer::IpcWriteOptions::default(); + let response: SchemaResult = SchemaAsIpc::new(&schema, &options) + .try_into() + .expect("Error encoding schema"); + + Ok(Response::new(response)) } async fn do_get( @@ -252,7 +353,7 @@ impl FlightService for TestFlightServer { .take() .ok_or_else(|| Status::internal("No do_get response configured"))?; - let batch_stream = futures::stream::iter(batches); + let batch_stream = futures::stream::iter(batches).map_err(Into::into); let stream = FlightDataEncoderBuilder::new() .build(batch_stream) @@ -284,16 +385,40 @@ impl FlightService for TestFlightServer { async fn do_action( &self, - _request: Request, + request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Implement do_action")) + self.save_metadata(&request); + let mut state = self.state.lock().expect("mutex not poisoned"); + + state.do_action_request = Some(request.into_inner()); + + let results: Vec<_> = state + .do_action_response + .take() + .ok_or_else(|| Status::internal("No do_action response configured"))?; + + let results_stream = futures::stream::iter(results); + + Ok(Response::new(results_stream.boxed())) } async fn list_actions( &self, - _request: Request, + request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Implement list_actions")) + self.save_metadata(&request); + let mut state = self.state.lock().expect("mutex not poisoned"); + + state.list_actions_request = Some(request.into_inner()); + + let actions: Vec<_> = state + .list_actions_response + .take() + .ok_or_else(|| Status::internal("No list_actions response configured"))?; + + let action_stream = futures::stream::iter(actions); + + Ok(Response::new(action_stream.boxed())) } async fn do_exchange(