Skip to content

Commit

Permalink
feat: expose DoGet response headers & trailers
Browse files Browse the repository at this point in the history
  • Loading branch information
crepererum committed Aug 23, 2023
1 parent 90449ff commit 852651e
Show file tree
Hide file tree
Showing 8 changed files with 318 additions and 18 deletions.
3 changes: 3 additions & 0 deletions arrow-flight/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ cli = ["arrow-cast/prettyprint", "clap", "tracing-log", "tracing-subscriber", "t
[dev-dependencies]
arrow-cast = { workspace = true, features = ["prettyprint"] }
assert_cmd = "2.0.8"
http = "0.2.9"
http-body = "0.4.5"
pin-project-lite = "0.2"
tempfile = "3.3"
tokio-stream = { version = "0.1", features = ["net"] }
tower = "0.4.13"
Expand Down
20 changes: 9 additions & 11 deletions arrow-flight/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
use std::task::Poll;

use crate::{
decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient, Action,
ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
HandshakeRequest, PutResult, Ticket,
decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient,
trailers::extract_trailers, Action, ActionType, Criteria, Empty, FlightData,
FlightDescriptor, FlightInfo, HandshakeRequest, PutResult, Ticket,
};
use arrow_schema::Schema;
use bytes::Bytes;
Expand Down Expand Up @@ -204,16 +204,14 @@ impl FlightClient {
pub async fn do_get(&mut self, ticket: Ticket) -> Result<FlightRecordBatchStream> {
let request = self.make_request(ticket);

let response_stream = self
.inner
.do_get(request)
.await?
.into_inner()
.map_err(FlightError::Tonic);
let (md, response_stream, _ext) = self.inner.do_get(request).await?.into_parts();
let (response_stream, trailers) = extract_trailers(response_stream);

Ok(FlightRecordBatchStream::new_from_flight_data(
response_stream,
))
response_stream.map_err(FlightError::Tonic),
)
.with_headers(md)
.with_trailers(trailers))
}

/// Make a `GetFlightInfo` call to the server with the provided
Expand Down
43 changes: 41 additions & 2 deletions arrow-flight/src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use crate::{utils::flight_data_to_arrow_batch, FlightData};
use crate::{trailers::LazyTrailers, utils::flight_data_to_arrow_batch, FlightData};
use arrow_array::{ArrayRef, RecordBatch};
use arrow_buffer::Buffer;
use arrow_schema::{Schema, SchemaRef};
Expand All @@ -24,6 +24,7 @@ use futures::{ready, stream::BoxStream, Stream, StreamExt};
use std::{
collections::HashMap, convert::TryFrom, fmt::Debug, pin::Pin, sync::Arc, task::Poll,
};
use tonic::metadata::MetadataMap;

use crate::error::{FlightError, Result};

Expand Down Expand Up @@ -82,13 +83,23 @@ use crate::error::{FlightError, Result};
/// ```
#[derive(Debug)]
pub struct FlightRecordBatchStream {
/// Optional grpc header metadata.
headers: MetadataMap,

/// Optional grpc trailer metadata.
trailers: Option<LazyTrailers>,

inner: FlightDataDecoder,
}

impl FlightRecordBatchStream {
/// Create a new [`FlightRecordBatchStream`] from a decoded stream
pub fn new(inner: FlightDataDecoder) -> Self {
Self { inner }
Self {
inner,
headers: MetadataMap::default(),
trailers: None,
}
}

/// Create a new [`FlightRecordBatchStream`] from a stream of [`FlightData`]
Expand All @@ -98,9 +109,36 @@ impl FlightRecordBatchStream {
{
Self {
inner: FlightDataDecoder::new(inner),
headers: MetadataMap::default(),
trailers: None,
}
}

/// Record response headers.
pub fn with_headers(self, headers: MetadataMap) -> Self {
Self { headers, ..self }
}

/// Record response trailers.
pub fn with_trailers(self, trailers: LazyTrailers) -> Self {
Self {
trailers: Some(trailers),
..self
}
}

/// Headers attached to this stream.
pub fn headers(&self) -> &MetadataMap {
&self.headers
}

/// Trailers attached to this stream.
///
/// This is only filled when the entire stream was consumed.
pub fn trailers(&self) -> Option<MetadataMap> {
self.trailers.as_ref().and_then(|trailers| trailers.get())
}

/// Has a message defining the schema been received yet?
#[deprecated = "use schema().is_some() instead"]
pub fn got_schema(&self) -> bool {
Expand All @@ -117,6 +155,7 @@ impl FlightRecordBatchStream {
self.inner
}
}

impl futures::Stream for FlightRecordBatchStream {
type Item = Result<RecordBatch>;

Expand Down
3 changes: 3 additions & 0 deletions arrow-flight/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ pub use gen::Result;
pub use gen::SchemaResult;
pub use gen::Ticket;

/// Helper to extract HTTP/gRPC trailers from a tonic stream.
mod trailers;

pub mod utils;

#[cfg(feature = "flight-sql-experimental")]
Expand Down
92 changes: 92 additions & 0 deletions arrow-flight/src/trailers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::{
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
};

use futures::{ready, FutureExt, Stream, StreamExt};
use tonic::{metadata::MetadataMap, Status, Streaming};

/// Extract trailers from [`Streaming`] [tonic] response.
pub fn extract_trailers<T>(s: Streaming<T>) -> (ExtractTrailersStream<T>, LazyTrailers) {
let trailers: SharedTrailers = Default::default();
let stream = ExtractTrailersStream {
inner: s,
trailers: Arc::clone(&trailers),
};
let lazy_trailers = LazyTrailers { trailers };
(stream, lazy_trailers)
}

type SharedTrailers = Arc<Mutex<Option<MetadataMap>>>;

/// [Stream] that stores the gRPC trailers into [`LazyTrailers`].
///
/// See [`extract_trailers`] for construction.
#[derive(Debug)]
pub struct ExtractTrailersStream<T> {
inner: Streaming<T>,
trailers: SharedTrailers,
}

impl<T> Stream for ExtractTrailersStream<T> {
type Item = Result<T, Status>;

fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
let res = ready!(self.inner.poll_next_unpin(cx));

if res.is_none() {
// stream exhausted => trailers should available
if let Some(trailers) = self
.inner
.trailers()
.now_or_never()
.and_then(|res| res.ok())
.flatten()
{
*self.trailers.lock().expect("poisoned") = Some(trailers);
}
}

Poll::Ready(res)
}

fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}

/// gRPC trailers that are extracted by [`ExtractTrailersStream`].
///
/// See [`extract_trailers`] for construction.
#[derive(Debug)]
pub struct LazyTrailers {
trailers: SharedTrailers,
}

impl LazyTrailers {
/// gRPC trailers that are known at the end of a stream.
pub fn get(&self) -> Option<MetadataMap> {
self.trailers.lock().expect("poisoned").clone()
}
}
31 changes: 27 additions & 4 deletions arrow-flight/tests/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
mod common {
pub mod server;
pub mod trailers_layer;
}
use arrow_array::{RecordBatch, UInt64Array};
use arrow_flight::{
Expand All @@ -28,7 +29,7 @@ use arrow_flight::{
};
use arrow_schema::{DataType, Field, Schema};
use bytes::Bytes;
use common::server::TestFlightServer;
use common::{server::TestFlightServer, trailers_layer::TrailersLayer};
use futures::{Future, StreamExt, TryStreamExt};
use tokio::{net::TcpListener, task::JoinHandle};
use tonic::{
Expand Down Expand Up @@ -158,18 +159,39 @@ async fn test_do_get() {

let response = vec![Ok(batch.clone())];
test_server.set_do_get_response(response);
let response_stream = client
let mut response_stream = client
.do_get(ticket.clone())
.await
.expect("error making request");

assert_eq!(
response_stream
.headers()
.get("test-resp-header")
.expect("header exists")
.to_str()
.unwrap(),
"some_val",
);

let expected_response = vec![batch];
let response: Vec<_> = response_stream
let response: Vec<_> = (&mut response_stream)
.try_collect()
.await
.expect("Error streaming data");

assert_eq!(response, expected_response);

assert_eq!(
response_stream
.trailers()
.expect("stream exhausted")
.get("test-trailer")
.expect("trailer exists")
.to_str()
.unwrap(),
"trailer_val",
);

assert_eq!(test_server.take_do_get_request(), Some(ticket));
ensure_metadata(&client, &test_server);
})
Expand Down Expand Up @@ -932,6 +954,7 @@ impl TestFixture {

let serve_future = tonic::transport::Server::builder()
.timeout(server_timeout)
.layer(TrailersLayer)
.add_service(test_server.service())
.serve_with_incoming_shutdown(
tokio_stream::wrappers::TcpListenerStream::new(listener),
Expand Down
6 changes: 5 additions & 1 deletion arrow-flight/tests/common/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,11 @@ impl FlightService for TestFlightServer {
.build(batch_stream)
.map_err(Into::into);

Ok(Response::new(stream.boxed()))
let mut resp = Response::new(stream.boxed());
resp.metadata_mut()
.insert("test-resp-header", "some_val".parse().unwrap());

Ok(resp)
}

async fn do_put(
Expand Down
Loading

0 comments on commit 852651e

Please sign in to comment.