From 852651e03484ff99eba4c0937731dcf1afdda3ff Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Wed, 23 Aug 2023 15:59:25 +0200 Subject: [PATCH] feat: expose DoGet response headers & trailers --- arrow-flight/Cargo.toml | 3 + arrow-flight/src/client.rs | 20 ++- arrow-flight/src/decode.rs | 43 +++++- arrow-flight/src/lib.rs | 3 + arrow-flight/src/trailers.rs | 92 +++++++++++++ arrow-flight/tests/client.rs | 31 ++++- arrow-flight/tests/common/server.rs | 6 +- arrow-flight/tests/common/trailers_layer.rs | 138 ++++++++++++++++++++ 8 files changed, 318 insertions(+), 18 deletions(-) create mode 100644 arrow-flight/src/trailers.rs create mode 100644 arrow-flight/tests/common/trailers_layer.rs diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml index 3ed426a21fab..1a53dbddb13d 100644 --- a/arrow-flight/Cargo.toml +++ b/arrow-flight/Cargo.toml @@ -67,6 +67,9 @@ cli = ["arrow-cast/prettyprint", "clap", "tracing-log", "tracing-subscriber", "t [dev-dependencies] arrow-cast = { workspace = true, features = ["prettyprint"] } assert_cmd = "2.0.8" +http = "0.2.9" +http-body = "0.4.5" +pin-project-lite = "0.2" tempfile = "3.3" tokio-stream = { version = "0.1", features = ["net"] } tower = "0.4.13" diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs index 2c952fb3bfbf..2df55545218f 100644 --- a/arrow-flight/src/client.rs +++ b/arrow-flight/src/client.rs @@ -18,9 +18,9 @@ use std::task::Poll; use crate::{ - decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient, Action, - ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, - HandshakeRequest, PutResult, Ticket, + decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient, + trailers::extract_trailers, Action, ActionType, Criteria, Empty, FlightData, + FlightDescriptor, FlightInfo, HandshakeRequest, PutResult, Ticket, }; use arrow_schema::Schema; use bytes::Bytes; @@ -204,16 +204,14 @@ impl FlightClient { pub async fn do_get(&mut self, ticket: 
Ticket) -> Result { let request = self.make_request(ticket); - let response_stream = self - .inner - .do_get(request) - .await? - .into_inner() - .map_err(FlightError::Tonic); + let (md, response_stream, _ext) = self.inner.do_get(request).await?.into_parts(); + let (response_stream, trailers) = extract_trailers(response_stream); Ok(FlightRecordBatchStream::new_from_flight_data( - response_stream, - )) + response_stream.map_err(FlightError::Tonic), + ) + .with_headers(md) + .with_trailers(trailers)) } /// Make a `GetFlightInfo` call to the server with the provided diff --git a/arrow-flight/src/decode.rs b/arrow-flight/src/decode.rs index df74923332e3..2c181053f55a 100644 --- a/arrow-flight/src/decode.rs +++ b/arrow-flight/src/decode.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::{utils::flight_data_to_arrow_batch, FlightData}; +use crate::{trailers::LazyTrailers, utils::flight_data_to_arrow_batch, FlightData}; use arrow_array::{ArrayRef, RecordBatch}; use arrow_buffer::Buffer; use arrow_schema::{Schema, SchemaRef}; @@ -24,6 +24,7 @@ use futures::{ready, stream::BoxStream, Stream, StreamExt}; use std::{ collections::HashMap, convert::TryFrom, fmt::Debug, pin::Pin, sync::Arc, task::Poll, }; +use tonic::metadata::MetadataMap; use crate::error::{FlightError, Result}; @@ -82,13 +83,23 @@ use crate::error::{FlightError, Result}; /// ``` #[derive(Debug)] pub struct FlightRecordBatchStream { + /// Optional grpc header metadata. + headers: MetadataMap, + + /// Optional grpc trailer metadata. 
+ trailers: Option, + inner: FlightDataDecoder, } impl FlightRecordBatchStream { /// Create a new [`FlightRecordBatchStream`] from a decoded stream pub fn new(inner: FlightDataDecoder) -> Self { - Self { inner } + Self { + inner, + headers: MetadataMap::default(), + trailers: None, + } } /// Create a new [`FlightRecordBatchStream`] from a stream of [`FlightData`] @@ -98,9 +109,36 @@ impl FlightRecordBatchStream { { Self { inner: FlightDataDecoder::new(inner), + headers: MetadataMap::default(), + trailers: None, + } + } + + /// Record response headers. + pub fn with_headers(self, headers: MetadataMap) -> Self { + Self { headers, ..self } + } + + /// Record response trailers. + pub fn with_trailers(self, trailers: LazyTrailers) -> Self { + Self { + trailers: Some(trailers), + ..self } } + /// Headers attached to this stream. + pub fn headers(&self) -> &MetadataMap { + &self.headers + } + + /// Trailers attached to this stream. + /// + /// This is only filled when the entire stream was consumed. + pub fn trailers(&self) -> Option { + self.trailers.as_ref().and_then(|trailers| trailers.get()) + } + /// Has a message defining the schema been received yet? #[deprecated = "use schema().is_some() instead"] pub fn got_schema(&self) -> bool { @@ -117,6 +155,7 @@ impl FlightRecordBatchStream { self.inner } } + impl futures::Stream for FlightRecordBatchStream { type Item = Result; diff --git a/arrow-flight/src/lib.rs b/arrow-flight/src/lib.rs index 4163f2ceaa27..04edf266389c 100644 --- a/arrow-flight/src/lib.rs +++ b/arrow-flight/src/lib.rs @@ -111,6 +111,9 @@ pub use gen::Result; pub use gen::SchemaResult; pub use gen::Ticket; +/// Helper to extract HTTP/gRPC trailers from a tonic stream. 
+mod trailers; + pub mod utils; #[cfg(feature = "flight-sql-experimental")] diff --git a/arrow-flight/src/trailers.rs b/arrow-flight/src/trailers.rs new file mode 100644 index 000000000000..aba652ad64b7 --- /dev/null +++ b/arrow-flight/src/trailers.rs @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, +}; + +use futures::{ready, FutureExt, Stream, StreamExt}; +use tonic::{metadata::MetadataMap, Status, Streaming}; + +/// Extract trailers from [`Streaming`] [tonic] response. +pub fn extract_trailers(s: Streaming) -> (ExtractTrailersStream, LazyTrailers) { + let trailers: SharedTrailers = Default::default(); + let stream = ExtractTrailersStream { + inner: s, + trailers: Arc::clone(&trailers), + }; + let lazy_trailers = LazyTrailers { trailers }; + (stream, lazy_trailers) +} + +type SharedTrailers = Arc>>; + +/// [Stream] that stores the gRPC trailers into [`LazyTrailers`]. +/// +/// See [`extract_trailers`] for construction. 
+#[derive(Debug)] +pub struct ExtractTrailersStream { + inner: Streaming, + trailers: SharedTrailers, +} + +impl Stream for ExtractTrailersStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let res = ready!(self.inner.poll_next_unpin(cx)); + + if res.is_none() { + // stream exhausted => trailers should be available + if let Some(trailers) = self + .inner + .trailers() + .now_or_never() + .and_then(|res| res.ok()) + .flatten() + { + *self.trailers.lock().expect("poisoned") = Some(trailers); + } + } + + Poll::Ready(res) + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +/// gRPC trailers that are extracted by [`ExtractTrailersStream`]. +/// +/// See [`extract_trailers`] for construction. +#[derive(Debug)] +pub struct LazyTrailers { + trailers: SharedTrailers, +} + +impl LazyTrailers { + /// gRPC trailers that are known at the end of a stream. + pub fn get(&self) -> Option { + self.trailers.lock().expect("poisoned").clone() + } +} diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs index 8ea542879a27..2d1dc76ae74b 100644 --- a/arrow-flight/tests/client.rs +++ b/arrow-flight/tests/client.rs @@ -19,6 +19,7 @@ mod common { pub mod server; + pub mod trailers_layer; } use arrow_array::{RecordBatch, UInt64Array}; use arrow_flight::{ @@ -28,7 +29,7 @@ use arrow_flight::{ }; use arrow_schema::{DataType, Field, Schema}; use bytes::Bytes; -use common::server::TestFlightServer; +use common::{server::TestFlightServer, trailers_layer::TrailersLayer}; use futures::{Future, StreamExt, TryStreamExt}; use tokio::{net::TcpListener, task::JoinHandle}; use tonic::{ @@ -158,18 +159,39 @@ async fn test_do_get() { let response = vec![Ok(batch.clone())]; test_server.set_do_get_response(response); - let response_stream = client + let mut response_stream = client .do_get(ticket.clone()) .await .expect("error making request"); + assert_eq!( + response_stream + .headers() + 
.get("test-resp-header") + .expect("header exists") + .to_str() + .unwrap(), + "some_val", + ); + let expected_response = vec![batch]; - let response: Vec<_> = response_stream + let response: Vec<_> = (&mut response_stream) .try_collect() .await .expect("Error streaming data"); - assert_eq!(response, expected_response); + + assert_eq!( + response_stream + .trailers() + .expect("stream exhausted") + .get("test-trailer") + .expect("trailer exists") + .to_str() + .unwrap(), + "trailer_val", + ); + assert_eq!(test_server.take_do_get_request(), Some(ticket)); ensure_metadata(&client, &test_server); }) @@ -932,6 +954,7 @@ impl TestFixture { let serve_future = tonic::transport::Server::builder() .timeout(server_timeout) + .layer(TrailersLayer) .add_service(test_server.service()) .serve_with_incoming_shutdown( tokio_stream::wrappers::TcpListenerStream::new(listener), diff --git a/arrow-flight/tests/common/server.rs b/arrow-flight/tests/common/server.rs index b87019d632c4..c575d12bbf52 100644 --- a/arrow-flight/tests/common/server.rs +++ b/arrow-flight/tests/common/server.rs @@ -359,7 +359,11 @@ impl FlightService for TestFlightServer { .build(batch_stream) .map_err(Into::into); - Ok(Response::new(stream.boxed())) + let mut resp = Response::new(stream.boxed()); + resp.metadata_mut() + .insert("test-resp-header", "some_val".parse().unwrap()); + + Ok(resp) } async fn do_put( diff --git a/arrow-flight/tests/common/trailers_layer.rs b/arrow-flight/tests/common/trailers_layer.rs new file mode 100644 index 000000000000..9e6be0dcf0da --- /dev/null +++ b/arrow-flight/tests/common/trailers_layer.rs @@ -0,0 +1,138 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures::ready; +use http::{HeaderValue, Request, Response}; +use http_body::SizeHint; +use pin_project_lite::pin_project; +use tower::{Layer, Service}; + +#[derive(Debug, Copy, Clone, Default)] +pub struct TrailersLayer; + +impl Layer for TrailersLayer { + type Service = TrailersService; + + fn layer(&self, service: S) -> Self::Service { + TrailersService { service } + } +} + +#[derive(Debug, Clone)] +pub struct TrailersService { + service: S, +} + +impl Service> for TrailersService +where + S: Service, Response = Response>, + ResBody: http_body::Body, +{ + type Response = Response>; + type Error = S::Error; + type Future = WrappedFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, request: Request) -> Self::Future { + WrappedFuture { + inner: self.service.call(request), + } + } +} + +pin_project! 
{ + #[derive(Debug)] + pub struct WrappedFuture { + #[pin] + inner: F, + } +} + +impl Future for WrappedFuture +where + F: Future, Error>>, + ResBody: http_body::Body, +{ + type Output = Result>, Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let result: Result, Error> = + ready!(self.as_mut().project().inner.poll(cx)); + + match result { + Ok(response) => { + Poll::Ready(Ok(response.map(|body| WrappedBody { inner: body }))) + } + Err(e) => Poll::Ready(Err(e)), + } + } +} + +pin_project! { + #[derive(Debug)] + pub struct WrappedBody { + #[pin] + inner: B, + } +} + +impl http_body::Body for WrappedBody { + type Data = B::Data; + type Error = B::Error; + + fn poll_data( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + self.as_mut().project().inner.poll_data(cx) + } + + fn poll_trailers( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + let result: Result, Self::Error> = + ready!(self.as_mut().project().inner.poll_trailers(cx)); + + let mut trailers = http::header::HeaderMap::new(); + trailers.insert("test-trailer", HeaderValue::from_static("trailer_val")); + + match result { + Ok(Some(mut existing)) => { + existing.extend(trailers.iter().map(|(k, v)| (k.clone(), v.clone()))); + Poll::Ready(Ok(Some(existing))) + } + Ok(None) => Poll::Ready(Ok(Some(trailers))), + Err(e) => Poll::Ready(Err(e)), + } + } + + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } + + fn size_hint(&self) -> SizeHint { + self.inner.size_hint() + } +}