-
Notifications
You must be signed in to change notification settings - Fork 842
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,10 +15,7 @@ | |
// specific language governing permissions and limitations | ||
// under the License. | ||
|
||
use std::{ | ||
sync::{Arc, Mutex}, | ||
task::Poll, | ||
}; | ||
use std::task::Poll; | ||
|
||
use crate::{ | ||
decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient, Action, | ||
|
@@ -31,7 +28,7 @@ use futures::{ | |
future::ready, | ||
ready, | ||
stream::{self, BoxStream}, | ||
Stream, StreamExt, TryStreamExt, | ||
FutureExt, Stream, StreamExt, TryStreamExt, | ||
}; | ||
use tonic::{metadata::MetadataMap, transport::Channel}; | ||
|
||
|
@@ -310,21 +307,36 @@ impl FlightClient { | |
&mut self, | ||
request: S, | ||
) -> Result<BoxStream<'static, Result<PutResult>>> { | ||
let ok_stream = FallibleStream::new(request.boxed()); | ||
let builder = ok_stream.builder(); | ||
let (sender, mut receiver) = futures::channel::oneshot::channel(); | ||
|
||
// Intercepts client errors and sends them to the oneshot channel above | ||
let mut request = Box::pin(request); // Pin to heap | ||
This comment has been minimized.
Sorry, something went wrong. |
||
let mut sender = Some(sender); // Wrap into Option so can be taken | ||
let request_stream = futures::stream::poll_fn(move |cx| { | ||
Poll::Ready(match ready!(request.poll_next_unpin(cx)) { | ||
Some(Ok(data)) => Some(data), | ||
Some(Err(e)) => { | ||
let _ = sender.take().unwrap().send(e); | ||
This comment has been minimized.
Sorry, something went wrong.
alamb
Contributor
|
||
None | ||
} | ||
None => None, | ||
}) | ||
}); | ||
|
||
// send ok result to the server | ||
let request = self.make_request(ok_stream); | ||
let request = self.make_request(request_stream); | ||
let mut response_stream = self.inner.do_put(request).await?.into_inner(); | ||
|
||
let response_stream = self | ||
.inner | ||
.do_put(request) | ||
.await? | ||
.into_inner() | ||
.map_err(FlightError::Tonic); | ||
// Forwards errors from the error oneshot with priority over responses from server | ||
let error_stream = futures::stream::poll_fn(move |cx| { | ||
if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) { | ||
return Poll::Ready(Some(Err(err))); | ||
} | ||
let next = ready!(response_stream.poll_next_unpin(cx)); | ||
Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic))) | ||
}); | ||
|
||
// combine the response from the server and any error from the client | ||
Ok(builder.build(response_stream.boxed())) | ||
Ok(error_stream.boxed()) | ||
} | ||
|
||
/// Make a `DoExchange` call to the server with the provided | ||
|
@@ -549,118 +561,3 @@ impl FlightClient { | |
request | ||
} | ||
} | ||
|
||
/// A stream that reads `Results`, passing along Ok variants, | ||
/// and saving any Errors seen to be forward along with responses | ||
/// | ||
/// If the input stream produces an an error, the error is saved in `err` | ||
/// and this stream is ended (the inner is not pollled any more) | ||
/// | ||
/// The setup of copying errors to result stream looks like this: | ||
/// | ||
/// ```text | ||
/// input: ---> (Stream of Result<FlightData>) ---- (Stream of FlightData) ---- network ----> Server | ||
/// | | | ||
/// | (errors copied to output) | | ||
/// v | | ||
/// output: <-- (Stream of Result<FlightData>) <--- network --(Stream of Result<PutResult>)------+ | ||
/// ``` | ||
struct FallibleStream { | ||
input_stream: BoxStream<'static, Result<FlightData>>, | ||
err: Arc<Mutex<Option<FlightError>>>, | ||
done: bool, | ||
} | ||
|
||
impl FallibleStream { | ||
fn new(input_stream: BoxStream<'static, Result<FlightData>>) -> Self { | ||
Self { | ||
input_stream, | ||
done: false, | ||
err: Arc::new(Mutex::new(None)), | ||
} | ||
} | ||
|
||
/// Returns a builder that wraps result streams and injects error | ||
fn builder(&self) -> StreamWrapperBuilder { | ||
StreamWrapperBuilder { | ||
maybe_err: Arc::clone(&self.err), | ||
} | ||
} | ||
} | ||
|
||
impl Stream for FallibleStream { | ||
type Item = FlightData; | ||
|
||
fn poll_next( | ||
mut self: std::pin::Pin<&mut Self>, | ||
cx: &mut std::task::Context<'_>, | ||
) -> Poll<Option<Self::Item>> { | ||
if self.done { | ||
return Poll::Ready(None); | ||
} | ||
|
||
match ready!(self.input_stream.poll_next_unpin(cx)) { | ||
Some(data) => match data { | ||
Ok(ok) => Poll::Ready(Some(ok)), | ||
Err(e) => { | ||
*self.err.lock().expect("non poisoned") = Some(e); | ||
self.done = true; | ||
Poll::Ready(None) | ||
} | ||
}, | ||
// input stream was done | ||
None => { | ||
self.done = true; | ||
Poll::Ready(None) | ||
} | ||
} | ||
} | ||
} | ||
|
||
/// A builder for wrapping server result streams that return either | ||
/// the error from the provided client stream or the error from the server | ||
struct StreamWrapperBuilder { | ||
maybe_err: Arc<Mutex<Option<FlightError>>>, | ||
} | ||
|
||
impl StreamWrapperBuilder { | ||
/// wraps response stream to return items from response_stream or | ||
/// the client stream error, if any | ||
/// Produce a stream that reads results from the server, first | ||
/// checking to see if the client stream generated an error | ||
fn build( | ||
self, | ||
response_stream: BoxStream<'static, Result<PutResult>>, | ||
) -> BoxStream<'static, Result<PutResult>> { | ||
let state = StreamAndError { | ||
maybe_err: self.maybe_err, | ||
response_stream, | ||
}; | ||
|
||
futures::stream::unfold(state, |mut state| async move { | ||
state.next().await.map(|item| (item, state)) | ||
}) | ||
.boxed() | ||
} | ||
} | ||
|
||
struct StreamAndError { | ||
// error from a FallableStream | ||
maybe_err: Arc<Mutex<Option<FlightError>>>, | ||
response_stream: BoxStream<'static, Result<PutResult>>, | ||
} | ||
|
||
impl StreamAndError { | ||
/// get the next result to pass along | ||
async fn next(&mut self) -> Option<Result<PutResult>> { | ||
// if the client made an error return that | ||
let next_item = self.maybe_err.lock().expect("non poisoned").take(); | ||
if let Some(e) = next_item { | ||
Some(Err(e)) | ||
} | ||
// otherwise return the next item from the server, if any | ||
else { | ||
self.response_stream.next().await | ||
} | ||
} | ||
} |
It is probably possible to avoid this with some pin projection dance, given we are performing networked IO I doubt this really matters and simplifies things