diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs index 7a239ff085fc..fe1292fcff6e 100644 --- a/arrow-flight/src/client.rs +++ b/arrow-flight/src/client.rs @@ -15,10 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::{ - sync::{Arc, Mutex}, - task::Poll, -}; +use std::task::Poll; use crate::{ decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient, Action, @@ -31,7 +28,7 @@ use futures::{ future::ready, ready, stream::{self, BoxStream}, - Stream, StreamExt, TryStreamExt, + FutureExt, Stream, StreamExt, TryStreamExt, }; use tonic::{metadata::MetadataMap, transport::Channel}; @@ -310,21 +307,36 @@ impl FlightClient { &mut self, request: S, ) -> Result>> { - let ok_stream = FallibleStream::new(request.boxed()); - let builder = ok_stream.builder(); + let (sender, mut receiver) = futures::channel::oneshot::channel(); + + // Intercepts client errors and sends them to the oneshot channel above + let mut request = Box::pin(request); // Pin to heap + let mut sender = Some(sender); // Wrap into Option so can be taken + let request_stream = futures::stream::poll_fn(move |cx| { + Poll::Ready(match ready!(request.poll_next_unpin(cx)) { + Some(Ok(data)) => Some(data), + Some(Err(e)) => { + let _ = sender.take().unwrap().send(e); + None + } + None => None, + }) + }); - // send ok result to the server - let request = self.make_request(ok_stream); + let request = self.make_request(request_stream); + let mut response_stream = self.inner.do_put(request).await?.into_inner(); - let response_stream = self - .inner - .do_put(request) - .await? - .into_inner() - .map_err(FlightError::Tonic); + // Forwards errors from the error oneshot with priority over responses from server + let error_stream = futures::stream::poll_fn(move |cx| { + if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) { + return Poll::Ready(Some(Err(err))); + } + let next = ready!(response_stream.poll_next_unpin(cx)); + Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic))) + }); // combine the response from the server and any error from the client - Ok(builder.build(response_stream.boxed())) + Ok(error_stream.boxed()) } /// Make a `DoExchange` call to the server with the provided @@ -549,118 +561,3 @@ impl FlightClient { request } } - -/// A stream that reads `Results`, passing along Ok variants, -/// and saving any Errors seen to be forward along with responses -/// -/// If the input stream produces an an error, the error is saved in `err` -/// and this stream is ended (the inner is not pollled any more) -/// -/// The setup of copying errors to result stream looks like this: -/// -/// ```text -/// input: ---> (Stream of Result) ---- (Stream of FlightData) ---- network ----> Server -/// | | -/// | (errors copied to output) | -/// v | -/// output: <-- (Stream of Result) <--- network --(Stream of Result)------+ -/// ``` -struct FallibleStream { - input_stream: BoxStream<'static, Result>, - err: Arc>>, - done: bool, -} - -impl FallibleStream { - fn new(input_stream: BoxStream<'static, Result>) -> Self { - Self { - input_stream, - done: false, - err: Arc::new(Mutex::new(None)), - } - } - - /// Returns a builder that wraps result streams and injects error - fn builder(&self) -> StreamWrapperBuilder { - StreamWrapperBuilder { - maybe_err: Arc::clone(&self.err), - } - } -} - -impl Stream for FallibleStream { - type Item = FlightData; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - if self.done { - return Poll::Ready(None); - } - - match ready!(self.input_stream.poll_next_unpin(cx)) { - Some(data) => match data { - Ok(ok) => Poll::Ready(Some(ok)), - Err(e) => { - *self.err.lock().expect("non poisoned") = Some(e); - self.done = true; - Poll::Ready(None) - } - }, - // input stream was done - None => { - self.done = true; - Poll::Ready(None) - } - } - } -} - -/// A builder for wrapping server result streams that return either -/// the error from the provided client stream or the error from the server -struct StreamWrapperBuilder { - maybe_err: Arc>>, -} - -impl StreamWrapperBuilder { - /// wraps response stream to return items from response_stream or - /// the client stream error, if any - /// Produce a stream that reads results from the server, first - /// checking to see if the client stream generated an error - fn build( - self, - response_stream: BoxStream<'static, Result>, - ) -> BoxStream<'static, Result> { - let state = StreamAndError { - maybe_err: self.maybe_err, - response_stream, - }; - - futures::stream::unfold(state, |mut state| async move { - state.next().await.map(|item| (item, state)) - }) - .boxed() - } -} - -struct StreamAndError { - // error from a FallableStream - maybe_err: Arc>>, - response_stream: BoxStream<'static, Result>, -} - -impl StreamAndError { - /// get the next result to pass along - async fn next(&mut self) -> Option> { - // if the client made an error return that - let next_item = self.maybe_err.lock().expect("non poisoned").take(); - if let Some(e) = next_item { - Some(Err(e)) - } - // otherwise return the next item from the server, if any - else { - self.response_stream.next().await - } - } -}