diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs index 2df55545218f..8793f7834bfb 100644 --- a/arrow-flight/src/client.rs +++ b/arrow-flight/src/client.rs @@ -19,7 +19,7 @@ use std::task::Poll; use crate::{ decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient, - trailers::extract_trailers, Action, ActionType, Criteria, Empty, FlightData, + trailers::extract_lazy_trailers, Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, PutResult, Ticket, }; use arrow_schema::Schema; @@ -205,7 +205,7 @@ impl FlightClient { let request = self.make_request(ticket); let (md, response_stream, _ext) = self.inner.do_get(request).await?.into_parts(); - let (response_stream, trailers) = extract_trailers(response_stream); + let (response_stream, trailers) = extract_lazy_trailers(response_stream); Ok(FlightRecordBatchStream::new_from_flight_data( response_stream.map_err(FlightError::Tonic), diff --git a/arrow-flight/src/trailers.rs b/arrow-flight/src/trailers.rs index aba652ad64b7..1c6578e6cd3e 100644 --- a/arrow-flight/src/trailers.rs +++ b/arrow-flight/src/trailers.rs @@ -24,8 +24,13 @@ use std::{ use futures::{ready, FutureExt, Stream, StreamExt}; use tonic::{metadata::MetadataMap, Status, Streaming}; -/// Extract trailers from [`Streaming`] [tonic] response. -pub fn extract_trailers(s: Streaming) -> (ExtractTrailersStream, LazyTrailers) { +/// Extract [`LazyTrailers`] from [`Streaming`] [tonic] response. +/// +/// Note that [`LazyTrailers`] has inner mutability and will only hold actual data after [`ExtractTrailersStream`] is +/// fully consumed (dropping it is not required though). +pub fn extract_lazy_trailers( + s: Streaming, +) -> (ExtractTrailersStream, LazyTrailers) { let trailers: SharedTrailers = Default::default(); let stream = ExtractTrailersStream { inner: s, diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs index 2d1dc76ae74b..1b9891e121fa 100644 --- a/arrow-flight/tests/client.rs +++ b/arrow-flight/tests/client.rs @@ -174,6 +174,9 @@ async fn test_do_get() { "some_val", ); + // trailers are not available before stream exhaustion + assert!(response_stream.trailers().is_none()); + let expected_response = vec![batch]; let response: Vec<_> = (&mut response_stream) .try_collect()