diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index acfbd9b53030..ec4fe323b267 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -104,6 +104,41 @@ use futures::{ready, stream::BoxStream, Stream, StreamExt}; /// # } /// ``` /// +/// # Example: Determining schema of encoded data +/// +/// Encoding flight data may hydrate dictionaries, see [`DictionaryHandling`] for more information, +/// which changes the schema of the encoded data compared to the input record batches. +/// The fully hydrated schema can be accessed using the [`FlightDataEncoder::known_schema`] method +/// and explicitly informing the builder of the schema using [`FlightDataEncoderBuilder::with_schema`]. +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{ArrayRef, RecordBatch, UInt32Array}; +/// # async fn f() { +/// # let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]); +/// # let batch = RecordBatch::try_from_iter(vec![ +/// # ("a", Arc::new(c1) as ArrayRef) +/// # ]) +/// # .expect("cannot create record batch"); +/// use arrow_flight::encode::FlightDataEncoderBuilder; +/// +/// // Get the schema of the input stream +/// let schema = batch.schema(); +/// +/// // Get an input stream of Result +/// let input_stream = futures::stream::iter(vec![Ok(batch)]); +/// +/// // Build a stream of `Result` (e.g. to return for do_get) +/// let flight_data_stream = FlightDataEncoderBuilder::new() +/// // Inform the builder of the input stream schema +/// .with_schema(schema) +/// .build(input_stream); +/// +/// // Retrieve the schema of the encoded data +/// let encoded_schema = flight_data_stream.known_schema(); +/// # } +/// ``` +/// /// [`FlightService::do_get`]: crate::flight_service_server::FlightService::do_get /// [`FlightError`]: crate::error::FlightError #[derive(Debug)] @@ -287,6 +322,12 @@ impl FlightDataEncoder { encoder } + /// Report the schema of the encoded data when known. + /// A schema is known when provided via the [`FlightDataEncoderBuilder::with_schema`] method. + pub fn known_schema(&self) -> Option { + self.schema.clone() + } + /// Place the `FlightData` in the queue to send fn queue_message(&mut self, mut data: FlightData) { if let Some(descriptor) = self.descriptor.take() { @@ -792,6 +833,53 @@ mod tests { verify_flight_round_trip(vec![batch1, batch2]).await; } + #[tokio::test] + async fn test_dictionary_hydration_known_schema() { + let arr1: DictionaryArray = vec!["a", "a", "b"].into_iter().collect(); + let arr2: DictionaryArray = vec!["c", "c", "d"].into_iter().collect(); + + let schema = Arc::new(Schema::new(vec![Field::new_dictionary( + "dict", + DataType::UInt16, + DataType::Utf8, + false, + )])); + let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap(); + let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap(); + + let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]); + + let encoder = FlightDataEncoderBuilder::default() + .with_schema(schema) + .build(stream); + let expected_schema = + Arc::new(Schema::new(vec![Field::new("dict", DataType::Utf8, false)])); + assert_eq!(Some(expected_schema), encoder.known_schema()) + } + + #[tokio::test] + async fn test_dictionary_resend_known_schema() { + let arr1: DictionaryArray = vec!["a", "a", "b"].into_iter().collect(); + let arr2: DictionaryArray = vec!["c", "c", "d"].into_iter().collect(); + + let schema = Arc::new(Schema::new(vec![Field::new_dictionary( + "dict", + DataType::UInt16, + DataType::Utf8, + false, + )])); + let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap(); + let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap(); + + let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]); + + let encoder = FlightDataEncoderBuilder::default() + .with_dictionary_handling(DictionaryHandling::Resend) + .with_schema(schema.clone()) + .build(stream); + assert_eq!(Some(schema), encoder.known_schema()) + } + #[tokio::test] async fn test_multiple_dictionaries_resend() { // Create a schema with two dictionary fields that have the same dict ID