Skip to content

Commit

Permalink
Simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Feb 13, 2023
1 parent 67723b6 commit 914948b
Showing 1 changed file with 28 additions and 131 deletions.
159 changes: 28 additions & 131 deletions arrow-flight/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::{
sync::{Arc, Mutex},
task::Poll,
};
use std::task::Poll;

use crate::{
decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient, Action,
Expand All @@ -31,7 +28,7 @@ use futures::{
future::ready,
ready,
stream::{self, BoxStream},
Stream, StreamExt, TryStreamExt,
FutureExt, Stream, StreamExt, TryStreamExt,
};
use tonic::{metadata::MetadataMap, transport::Channel};

Expand Down Expand Up @@ -310,21 +307,36 @@ impl FlightClient {
&mut self,
request: S,
) -> Result<BoxStream<'static, Result<PutResult>>> {
let ok_stream = FallibleStream::new(request.boxed());
let builder = ok_stream.builder();
let (sender, mut receiver) = futures::channel::oneshot::channel();

// Intercepts client errors and sends them to the oneshot channel above
let mut request = Box::pin(request); // Pin to heap

This comment has been minimized.

Copy link
@tustvold

tustvold Feb 13, 2023

Author Contributor

It is probably possible to avoid this with some pin projection dance, given we are performing networked IO I doubt this really matters and simplifies things

let mut sender = Some(sender); // Wrap into Option so can be taken
let request_stream = futures::stream::poll_fn(move |cx| {
Poll::Ready(match ready!(request.poll_next_unpin(cx)) {
Some(Ok(data)) => Some(data),
Some(Err(e)) => {
let _ = sender.take().unwrap().send(e);

This comment has been minimized.

Copy link
@alamb

alamb Feb 13, 2023

Contributor

I wonder if the stream could be polled again after it returned a None. If so, would this unwrap cause a panic?

This comment has been minimized.

Copy link
@tustvold

tustvold Feb 13, 2023

Author Contributor

That would be a violation of the stream contract, polling a stream after it has returned None is implementation defined and panicking is fine

This comment has been minimized.

Copy link
@alamb

alamb Feb 13, 2023

Contributor

And yet https://docs.rs/futures/0.3.26/futures/stream/trait.StreamExt.html#method.fuse exists so I thought it could happen? Maybe I am misremembering

This comment has been minimized.

Copy link
@tustvold

tustvold Feb 13, 2023

Author Contributor

Fuse exists to give it defined semantics, the onus is on the caller, not the stream implementation to ensure they either don't poll after None, or use Fuse

See https://docs.rs/futures/latest/futures/stream/trait.Stream.html#panics

Once a stream has finished (returned Ready(None) from poll_next), calling its poll_next method again may panic, block forever, or cause other kinds of problems; the Stream trait places no requirements on the effects of such a call.

This comment has been minimized.

Copy link
@alamb

alamb Feb 14, 2023

Contributor

Sorry -- I was trying to be pragmatic. Since we know some callers (ahem DataFusion) sometimes do this anyways, I thought it might be nicer to handle. I am fine with this approach as well

None
}
None => None,
})
});

// send ok result to the server
let request = self.make_request(ok_stream);
let request = self.make_request(request_stream);
let mut response_stream = self.inner.do_put(request).await?.into_inner();

let response_stream = self
.inner
.do_put(request)
.await?
.into_inner()
.map_err(FlightError::Tonic);
// Forwards errors from the error oneshot with priority over responses from server
let error_stream = futures::stream::poll_fn(move |cx| {
if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
return Poll::Ready(Some(Err(err)));
}
let next = ready!(response_stream.poll_next_unpin(cx));
Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
});

// combine the response from the server and any error from the client
Ok(builder.build(response_stream.boxed()))
Ok(error_stream.boxed())
}

/// Make a `DoExchange` call to the server with the provided
Expand Down Expand Up @@ -549,118 +561,3 @@ impl FlightClient {
request
}
}

/// A stream that reads `Results`, passing along Ok variants,
/// and saving any Errors seen to be forward along with responses
///
/// If the input stream produces an an error, the error is saved in `err`
/// and this stream is ended (the inner is not pollled any more)
///
/// The setup of copying errors to result stream looks like this:
///
/// ```text
/// input: ---> (Stream of Result<FlightData>) ---- (Stream of FlightData) ---- network ----> Server
/// | |
/// | (errors copied to output) |
/// v |
/// output: <-- (Stream of Result<FlightData>) <--- network --(Stream of Result<PutResult>)------+
/// ```
struct FallibleStream {
input_stream: BoxStream<'static, Result<FlightData>>,
err: Arc<Mutex<Option<FlightError>>>,
done: bool,
}

impl FallibleStream {
fn new(input_stream: BoxStream<'static, Result<FlightData>>) -> Self {
Self {
input_stream,
done: false,
err: Arc::new(Mutex::new(None)),
}
}

/// Returns a builder that wraps result streams and injects error
fn builder(&self) -> StreamWrapperBuilder {
StreamWrapperBuilder {
maybe_err: Arc::clone(&self.err),
}
}
}

impl Stream for FallibleStream {
type Item = FlightData;

fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
if self.done {
return Poll::Ready(None);
}

match ready!(self.input_stream.poll_next_unpin(cx)) {
Some(data) => match data {
Ok(ok) => Poll::Ready(Some(ok)),
Err(e) => {
*self.err.lock().expect("non poisoned") = Some(e);
self.done = true;
Poll::Ready(None)
}
},
// input stream was done
None => {
self.done = true;
Poll::Ready(None)
}
}
}
}

/// A builder for wrapping server result streams that return either
/// the error from the provided client stream or the error from the server
struct StreamWrapperBuilder {
maybe_err: Arc<Mutex<Option<FlightError>>>,
}

impl StreamWrapperBuilder {
/// wraps response stream to return items from response_stream or
/// the client stream error, if any
/// Produce a stream that reads results from the server, first
/// checking to see if the client stream generated an error
fn build(
self,
response_stream: BoxStream<'static, Result<PutResult>>,
) -> BoxStream<'static, Result<PutResult>> {
let state = StreamAndError {
maybe_err: self.maybe_err,
response_stream,
};

futures::stream::unfold(state, |mut state| async move {
state.next().await.map(|item| (item, state))
})
.boxed()
}
}

struct StreamAndError {
// error from a FallableStream
maybe_err: Arc<Mutex<Option<FlightError>>>,
response_stream: BoxStream<'static, Result<PutResult>>,
}

impl StreamAndError {
/// get the next result to pass along
async fn next(&mut self) -> Option<Result<PutResult>> {
// if the client made an error return that
let next_item = self.maybe_err.lock().expect("non poisoned").take();
if let Some(e) = next_item {
Some(Err(e))
}
// otherwise return the next item from the server, if any
else {
self.response_stream.next().await
}
}
}

0 comments on commit 914948b

Please sign in to comment.