Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement fallible streams for FlightClient::do_put #3464

Merged
merged 8 commits into from
Feb 23, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 130 additions & 7 deletions arrow-flight/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
// specific language governing permissions and limitations
// under the License.

use std::{
sync::{Arc, Mutex},
task::Poll,
};

use crate::{
decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient, Action,
ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
Expand All @@ -24,6 +29,7 @@ use arrow_schema::Schema;
use bytes::Bytes;
use futures::{
future::ready,
ready,
stream::{self, BoxStream},
Stream, StreamExt, TryStreamExt,
};
Expand Down Expand Up @@ -262,6 +268,15 @@ impl FlightClient {
/// [`Stream`](futures::Stream) of [`FlightData`] and returning a
/// stream of [`PutResult`].
///
/// # Note
///
/// The input stream is [`Result`] so that this can be connected
/// to a streaming data source (such as [`FlightDataEncoder`])
/// without having to buffer. If the input stream returns an error
/// that error will not be sent to the server, instead it will be
/// placed into the result stream and the server connection
/// terminated.
///
/// # Example:
/// ```no_run
/// # async fn run() {
Expand All @@ -279,9 +294,7 @@ impl FlightClient {
///
/// // encode the batch as a stream of `FlightData`
/// let flight_data_stream = FlightDataEncoderBuilder::new()
/// .build(futures::stream::iter(vec![Ok(batch)]))
/// // data encoder return Results, but do_put requires FlightData
/// .map(|batch|batch.unwrap());
/// .build(futures::stream::iter(vec![Ok(batch)]));
///
/// // send the stream and get the results as `PutResult`
/// let response: Vec<PutResult>= client
Expand All @@ -293,20 +306,25 @@ impl FlightClient {
/// .expect("error calling do_put");
/// # }
/// ```
pub async fn do_put<S: Stream<Item = FlightData> + Send + 'static>(
pub async fn do_put<S: Stream<Item = Result<FlightData>> + Send + 'static>(
&mut self,
request: S,
) -> Result<BoxStream<'static, Result<PutResult>>> {
let request = self.make_request(request);
let ok_stream = FallibleStream::new(request.boxed());
let builder = ok_stream.builder();

let response = self
// send ok result to the server
let request = self.make_request(ok_stream);

let response_stream = self
.inner
.do_put(request)
.await?
.into_inner()
.map_err(FlightError::Tonic);

Ok(response.boxed())
// combine the response from the server and any error from the client
Ok(builder.build(response_stream.boxed()))
}

/// Make a `DoExchange` call to the server with the provided
Expand Down Expand Up @@ -531,3 +549,108 @@ impl FlightClient {
request
}
}

/// A stream that reads `Result`s, passes along the `Ok` variants,
/// and saves any error seen so it can be forwarded along with responses.
///
/// If the input stream produces an error, the error is saved in `err`
/// and this stream is ended (the inner stream is not polled any more).
struct FallibleStream {
// the fallible input; only its `Ok` items are yielded by this stream
input_stream: BoxStream<'static, Result<FlightData>>,
// slot holding the first error produced by `input_stream`; the same
// `Arc` is handed to the `StreamWrapperBuilder` returned by `builder()`
err: Arc<Mutex<Option<FlightError>>>,
// set once `input_stream` has ended or errored; afterwards `poll_next`
// returns `None` without polling `input_stream` again
done: bool,
}

impl FallibleStream {
    /// Wraps `input_stream`, starting with an empty error slot and the
    /// stream not yet finished.
    fn new(input_stream: BoxStream<'static, Result<FlightData>>) -> Self {
        let err = Arc::new(Mutex::new(None));
        Self {
            input_stream,
            err,
            done: false,
        }
    }

    /// Creates a [`StreamWrapperBuilder`] that shares this stream's
    /// error slot, so an error recorded here can be observed by the
    /// result stream the builder produces.
    fn builder(&self) -> StreamWrapperBuilder {
        let maybe_err = Arc::clone(&self.err);
        StreamWrapperBuilder { maybe_err }
    }
}

impl Stream for FallibleStream {
    type Item = FlightData;

    /// Yields the `Ok` payloads of the input stream.
    ///
    /// On the first `Err` the error is stored in the shared `err` slot
    /// and the stream ends; once ended (by error or exhaustion of the
    /// input) every later poll returns `None` without touching the input.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // every field is `Unpin`, so plain mutable access is available
        let this = self.get_mut();

        if this.done {
            return Poll::Ready(None);
        }

        let next = match ready!(this.input_stream.poll_next_unpin(cx)) {
            Some(Ok(data)) => Some(data),
            Some(Err(e)) => {
                // remember the error for the response side, then stop
                *this.err.lock().expect("non poisoned") = Some(e);
                this.done = true;
                None
            }
            // input stream was done
            None => {
                this.done = true;
                None
            }
        };

        Poll::Ready(next)
    }
}

/// A builder for wrapping server result streams so that the wrapped
/// stream returns either the error from the provided client stream or
/// the items (including errors) from the server.
struct StreamWrapperBuilder {
// error slot shared with the `FallibleStream` whose `builder()` call
// created this value
maybe_err: Arc<Mutex<Option<FlightError>>>,
}

impl StreamWrapperBuilder {
    /// Wraps `response_stream` in a stream that, before reading each
    /// result from the server, first checks whether the client stream
    /// generated an error and yields that error instead if so.
    fn build(
        self,
        response_stream: BoxStream<'static, Result<PutResult>>,
    ) -> BoxStream<'static, Result<PutResult>> {
        let Self { maybe_err } = self;
        let initial = StreamAndError {
            maybe_err,
            response_stream,
        };

        // Thread the state through `unfold`: each step asks the state for
        // its next item and ends the stream when there is none.
        let wrapped = futures::stream::unfold(initial, |mut current| async move {
            let item = current.next().await?;
            Some((item, current))
        });

        wrapped.boxed()
    }
}

/// State behind the stream produced by `StreamWrapperBuilder::build`:
/// a response stream paired with the shared client-side error slot.
struct StreamAndError {
// error from a FallibleStream
maybe_err: Arc<Mutex<Option<FlightError>>>,
// results to pass along when no client error is pending
response_stream: BoxStream<'static, Result<PutResult>>,
}

impl StreamAndError {
/// get the next result to pass along
async fn next(&mut self) -> Option<Result<PutResult>> {
// if the client made an error return that
let next_item = self.maybe_err.lock().expect("non poisoned").take();
if let Some(e) = next_item {
Some(Err(e))
}
// otherwise return the next item from the server, if any
else {
self.response_stream.next().await
}
}
}
99 changes: 92 additions & 7 deletions arrow-flight/tests/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,10 @@ async fn test_do_put() {
test_server
.set_do_put_response(expected_response.clone().into_iter().map(Ok).collect());

let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok);

let response_stream = client
.do_put(futures::stream::iter(input_flight_data.clone()))
.do_put(input_stream)
.await
.expect("error making request");

Expand All @@ -266,15 +268,15 @@ async fn test_do_put() {
}

#[tokio::test]
async fn test_do_put_error() {
async fn test_do_put_error_server() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();

let input_flight_data = test_flight_data().await;

let response = client
.do_put(futures::stream::iter(input_flight_data.clone()))
.await;
let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok);

let response = client.do_put(input_stream).await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Err(e) => e,
Expand All @@ -290,7 +292,7 @@ async fn test_do_put_error() {
}

#[tokio::test]
async fn test_do_put_error_stream() {
async fn test_do_put_error_stream_server() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();

Expand All @@ -307,8 +309,10 @@ async fn test_do_put_error_stream() {

test_server.set_do_put_response(response);

let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok);

let response_stream = client
.do_put(futures::stream::iter(input_flight_data.clone()))
.do_put(input_stream)
.await
.expect("error making request");

Expand All @@ -326,6 +330,87 @@ async fn test_do_put_error_stream() {
.await;
}

#[tokio::test]
async fn test_do_put_error_client() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();

let e = Status::invalid_argument("bad arg: client");

// input stream to client sends good FlightData followed by an error
let input_flight_data = test_flight_data().await;
let input_stream = futures::stream::iter(input_flight_data.clone())
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added these two tests to show what is going on -- namely, that the error provided to the client `do_put` call needs to get to the returned stream even though the error did not come from the server

.map(Ok)
.chain(futures::stream::iter(vec![Err(FlightError::from(
e.clone(),
))]));

// server responds with one good message
let response = vec![Ok(PutResult {
app_metadata: Bytes::from("foo-metadata"),
})];
test_server.set_do_put_response(response);

let response_stream = client
.do_put(input_stream)
.await
.expect("error making request");

let response: Result<Vec<_>, _> = response_stream.try_collect().await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Err(e) => e,
};

// expect the error made from the client
expect_status(response, e);
// server still got the request messages until the client sent the error
assert_eq!(test_server.take_do_put_request(), Some(input_flight_data));
ensure_metadata(&client, &test_server);
})
.await;
}

/// When both the client input stream and the server produce an error,
/// the stream returned by `do_put` must report the client's error.
#[tokio::test]
async fn test_do_put_error_client_and_server() {
    do_test(|test_server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        let e_client = Status::invalid_argument("bad arg: client");
        let e_server = Status::invalid_argument("bad arg: server");

        // input stream to client: good FlightData followed by an error
        let input_flight_data = test_flight_data().await;
        let client_failure = FlightError::from(e_client.clone());
        let good_messages = futures::stream::iter(input_flight_data.clone()).map(Ok);
        let input_stream =
            good_messages.chain(futures::stream::iter(vec![Err(client_failure)]));

        // server responds with an error (e.g. because it got truncated data)
        test_server.set_do_put_response(vec![Err(e_server)]);

        let response_stream = client
            .do_put(input_stream)
            .await
            .expect("error making request");

        let collected: Result<Vec<_>, _> = response_stream.try_collect().await;
        let response = if let Err(e) = collected {
            e
        } else {
            panic!("unexpected success")
        };

        // expect the error made from the client (not the server)
        expect_status(response, e_client);
        // server still got the request messages until the client sent the error
        assert_eq!(test_server.take_do_put_request(), Some(input_flight_data));
        ensure_metadata(&client, &test_server);
    })
    .await;
}

#[tokio::test]
async fn test_do_exchange() {
do_test(|test_server, mut client| async move {
Expand Down