Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement fallible streams for FlightClient::do_put #3464

Merged
merged 8 commits into from
Feb 23, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 163 additions & 6 deletions arrow-flight/src/client.rs
Original file line number Diff line number Diff line change
@@ -15,6 +15,11 @@
// specific language governing permissions and limitations
// under the License.

use std::{
sync::{Arc, Mutex},
task::Poll,
};

use crate::{
decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient, Action,
ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
@@ -262,6 +267,15 @@ impl FlightClient {
/// [`Stream`](futures::Stream) of [`FlightData`] and returning a
/// stream of [`PutResult`].
///
/// # Note
///
/// The input stream is [`Result`] so that this can be connected
/// to a streaming data source (such as [`FlightDataEncoder`])
/// without having to buffer. If the input stream returns an error
/// that error will not be sent to the server, instead it will be
/// placed into the result stream and the server connection
/// terminated.
///
/// # Example:
/// ```no_run
/// # async fn run() {
@@ -279,9 +293,7 @@ impl FlightClient {
///
/// // encode the batch as a stream of `FlightData`
/// let flight_data_stream = FlightDataEncoderBuilder::new()
/// .build(futures::stream::iter(vec![Ok(batch)]))
/// // data encoder return Results, but do_put requires FlightData
/// .map(|batch|batch.unwrap());
/// .build(futures::stream::iter(vec![Ok(batch)]));
///
/// // send the stream and get the results as `PutResult`
/// let response: Vec<PutResult>= client
@@ -293,11 +305,14 @@ impl FlightClient {
/// .expect("error calling do_put");
/// # }
/// ```
pub async fn do_put<S: Stream<Item = FlightData> + Send + 'static>(
pub async fn do_put<S: Stream<Item = Result<FlightData>> + Send + 'static>(
&mut self,
request: S,
) -> Result<BoxStream<'static, Result<PutResult>>> {
let request = self.make_request(request);
let (ok_stream, err_stream) = split_stream(request.boxed());

// send ok result to the server
let request = self.make_request(ok_stream);

let response = self
.inner
@@ -306,7 +321,10 @@ impl FlightClient {
.into_inner()
.map_err(FlightError::Tonic);

Ok(response.boxed())
let err_stream = err_stream.map(Err);

// combine the response from the server and any error from the client
Ok(futures::stream::select_all([response.boxed(), err_stream.boxed()]).boxed())
}

/// Make a `DoExchange` call to the server with the provided
@@ -531,3 +549,142 @@ impl FlightClient {
request
}
}

// splits the input stream into an infallible flight data stream and a stream of errors
//
// TODO generify
fn split_stream(
alamb marked this conversation as resolved.
Show resolved Hide resolved
input_stream: BoxStream<'static, Result<FlightData>>,
) -> (SplitStreamOk, SplitStreamErr) {
// shared state: nothing buffered yet and the input not yet finished
let inner = SplitStream {
input_stream,
next_ok: None,
next_err: None,
done: false,
};
// both halves get a handle to the same state, guarded by a mutex
let inner = Arc::new(Mutex::new(inner));

// half that yields the `Ok` items
let ok_stream = SplitStreamOk {
inner: Arc::clone(&inner),
};

// half that yields the errors
let err_stream = SplitStreamErr {
inner: Arc::clone(&inner),
};

(ok_stream, err_stream)
}

/// State shared by the two halves created in `split_stream`.
struct SplitStream {
/// source of results; polled by `maybe_read`
input_stream: BoxStream<'static, Result<FlightData>>,
/// an `Ok` item read from the input but not yet taken by `poll_next_ok`
next_ok: Option<FlightData>,
/// an error read from the input but not yet taken by `poll_next_err`
next_err: Option<FlightError>,
/// set once the input has ended or produced an error; the poll methods
/// check it before attempting another read
done: bool,
}

impl SplitStream {
Copy link
Contributor

@Dandandan Dandandan Jan 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of splitting - this stream can stop whenever the first error appears and keep that first error if any, so it can be retrieved.

// Hands out the buffered `Ok` item, reading from the input when none is buffered.
//
// Yields `Ready(Some(..))` when an item is available, `Ready(None)` once
// `done` is set and nothing is buffered, and `Pending` when `maybe_read`
// could not obtain an item.
fn poll_next_ok(
    &mut self,
    cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<FlightData>> {
    loop {
        match self.next_ok.take() {
            Some(buffered) => break Poll::Ready(Some(buffered)),
            None if self.done => break Poll::Ready(None),
            None => {
                // nothing buffered yet: try to pull one more item from the input
                if !self.maybe_read(cx) {
                    break Poll::Pending;
                }
            }
        }
    }
}

// Hands out the buffered error, reading from the input when none is buffered.
//
// Yields `Ready(Some(..))` when an error is available, `Ready(None)` once
// `done` is set and no error is buffered, and `Pending` when `maybe_read`
// could not obtain an item.
fn poll_next_err(
    &mut self,
    cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<FlightError>> {
    loop {
        match self.next_err.take() {
            Some(error) => break Poll::Ready(Some(error)),
            None if self.done => break Poll::Ready(None),
            None => {
                // no error buffered yet: try to pull one more item from the input
                if !self.maybe_read(cx) {
                    break Poll::Pending;
                }
            }
        }
    }
}

// Reads at most one item from the input, but only while both buffer slots
// (`next_ok` and `next_err`) are empty.
//
// Returns `true` when the input produced something (an item, an error, or
// end of stream) and `false` when nothing was read.
//
// NOTE(review): when a slot is occupied this returns before polling the
// input, so this call registers no waker, yet both callers answer `false`
// with `Poll::Pending`. Confirm the consumer is guaranteed to be re-polled.
fn maybe_read(&mut self, cx: &mut std::task::Context<'_>) -> bool {
    let slots_empty = self.next_ok.is_none() && self.next_err.is_none();
    if !slots_empty {
        // can't take the next item until the buffered one has been handed out
        return false;
    }

    match self.input_stream.poll_next_unpin(cx) {
        Poll::Pending => false,
        Poll::Ready(Some(Ok(flight_data))) => {
            self.next_ok = Some(flight_data);
            true
        }
        Poll::Ready(Some(Err(error))) => {
            self.next_err = Some(error);
            // the first error also ends reading from the input
            self.done = true;
            true
        }
        Poll::Ready(None) => {
            self.done = true;
            true
        }
    }
}
}

/// returns only the Err responses from a stream of results
struct SplitStreamErr {
Copy link
Contributor

@Dandandan Dandandan Jan 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


/// Wraps a stream of results and remembers the error it stopped on.
struct FallibleStream {
    /// source of results
    input_stream: BoxStream<'static, Result<FlightData>>,
    /// set by `poll_next` when the input yields an `Err`
    err: Option<FlightError>,
}

impl Stream for FallibleStream {
    type Item = FlightData;

    /// Yields the `Ok` items of `input_stream`. When the input produces an
    /// `Err`, the error is stored in `err` and this poll returns
    /// `Ready(None)`; end of input and `Pending` are passed through.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // Field mutation through `Pin<&mut Self>` needs a mutable reference;
        // going through an immutable `self` binding is rejected by the
        // compiler. `get_mut` requires `Self: Unpin`, the same bound that
        // direct field access through the pin would need.
        let this = self.get_mut();

        match this.input_stream.poll_next_unpin(cx) {
            std::task::Poll::Ready(Some(Ok(data))) => std::task::Poll::Ready(Some(data)),
            std::task::Poll::Ready(Some(Err(e))) => {
                // keep the error for later retrieval and end the stream
                this.err = Some(e);
                std::task::Poll::Ready(None)
            }
            std::task::Poll::Ready(None) => std::task::Poll::Ready(None),
            std::task::Poll::Pending => std::task::Poll::Pending,
        }
    }
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes sense @Dandandan 👍 thank you for the suggestion

inner: Arc<Mutex<SplitStream>>,
}

impl Stream for SplitStreamOk {
    type Item = FlightData;

    /// Locks the shared state and delegates to `poll_next_ok`.
    ///
    /// Panics if the mutex is poisoned.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let mut shared = self.inner.lock().unwrap();
        shared.poll_next_ok(cx)
    }
}

/// returns only the Ok responses from a stream of results
struct SplitStreamOk {
/// state shared with the error half
inner: Arc<Mutex<SplitStream>>,
}

impl Stream for SplitStreamErr {
    type Item = FlightError;

    /// Locks the shared state and delegates to `poll_next_err`.
    ///
    /// Panics if the mutex is poisoned.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let mut shared = self.inner.lock().unwrap();
        shared.poll_next_err(cx)
    }
}
14 changes: 9 additions & 5 deletions arrow-flight/tests/client.rs
Original file line number Diff line number Diff line change
@@ -252,8 +252,10 @@ async fn test_do_put() {
test_server
.set_do_put_response(expected_response.clone().into_iter().map(Ok).collect());

let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok);

let response_stream = client
.do_put(futures::stream::iter(input_flight_data.clone()))
.do_put(input_stream)
.await
.expect("error making request");

@@ -276,9 +278,9 @@ async fn test_do_put_error() {

let input_flight_data = test_flight_data().await;

let response = client
.do_put(futures::stream::iter(input_flight_data.clone()))
.await;
let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok);

let response = client.do_put(input_stream).await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Err(e) => e,
@@ -309,8 +311,10 @@ async fn test_do_put_error_stream() {

test_server.set_do_put_response(response);

let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok);

let response_stream = client
.do_put(futures::stream::iter(input_flight_data.clone()))
.do_put(input_stream)
.await
.expect("error making request");