Skip to content

Commit

Permalink
feat: expose known_schema from FlightDataEncoder (#6688)
Browse files Browse the repository at this point in the history
With this change is now possible to inspect the encoded schema of the
encoded Flight data, which may differ from the input schema based on
dictionary hydration handling.

Fixes #6672
  • Loading branch information
nathanielc authored Nov 8, 2024
1 parent c36ff79 commit 9471bfb
Showing 1 changed file with 88 additions and 0 deletions.
88 changes: 88 additions & 0 deletions arrow-flight/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,41 @@ use futures::{ready, stream::BoxStream, Stream, StreamExt};
/// # }
/// ```
///
/// # Example: Determining schema of encoded data
///
/// Encoding flight data may hydrate dictionaries, see [`DictionaryHandling`] for more information,
/// which changes the schema of the encoded data compared to the input record batches.
/// The fully hydrated schema can be accessed using the [`FlightDataEncoder::known_schema`] method
/// and explicitly informing the builder of the schema using [`FlightDataEncoderBuilder::with_schema`].
///
/// ```
/// # use std::sync::Arc;
/// # use arrow_array::{ArrayRef, RecordBatch, UInt32Array};
/// # async fn f() {
/// # let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
/// # let batch = RecordBatch::try_from_iter(vec![
/// # ("a", Arc::new(c1) as ArrayRef)
/// # ])
/// # .expect("cannot create record batch");
/// use arrow_flight::encode::FlightDataEncoderBuilder;
///
/// // Get the schema of the input stream
/// let schema = batch.schema();
///
/// // Get an input stream of Result<RecordBatch, FlightError>
/// let input_stream = futures::stream::iter(vec![Ok(batch)]);
///
/// // Build a stream of `Result<FlightData>` (e.g. to return for do_get)
/// let flight_data_stream = FlightDataEncoderBuilder::new()
/// // Inform the builder of the input stream schema
/// .with_schema(schema)
/// .build(input_stream);
///
/// // Retrieve the schema of the encoded data
/// let encoded_schema = flight_data_stream.known_schema();
/// # }
/// ```
///
/// [`FlightService::do_get`]: crate::flight_service_server::FlightService::do_get
/// [`FlightError`]: crate::error::FlightError
#[derive(Debug)]
Expand Down Expand Up @@ -287,6 +322,12 @@ impl FlightDataEncoder {
encoder
}

/// Report the schema of the encoded data when known.
/// A schema is known when provided via the [`FlightDataEncoderBuilder::with_schema`] method.
pub fn known_schema(&self) -> Option<SchemaRef> {
self.schema.clone()
}

/// Place the `FlightData` in the queue to send
fn queue_message(&mut self, mut data: FlightData) {
if let Some(descriptor) = self.descriptor.take() {
Expand Down Expand Up @@ -792,6 +833,53 @@ mod tests {
verify_flight_round_trip(vec![batch1, batch2]).await;
}

#[tokio::test]
async fn test_dictionary_hydration_known_schema() {
let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();

let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
"dict",
DataType::UInt16,
DataType::Utf8,
false,
)]));
let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

let encoder = FlightDataEncoderBuilder::default()
.with_schema(schema)
.build(stream);
let expected_schema =
Arc::new(Schema::new(vec![Field::new("dict", DataType::Utf8, false)]));
assert_eq!(Some(expected_schema), encoder.known_schema())
}

#[tokio::test]
async fn test_dictionary_resend_known_schema() {
let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();

let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
"dict",
DataType::UInt16,
DataType::Utf8,
false,
)]));
let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

let encoder = FlightDataEncoderBuilder::default()
.with_dictionary_handling(DictionaryHandling::Resend)
.with_schema(schema.clone())
.build(stream);
assert_eq!(Some(schema), encoder.known_schema())
}

#[tokio::test]
async fn test_multiple_dictionaries_resend() {
// Create a schema with two dictionary fields that have the same dict ID
Expand Down

0 comments on commit 9471bfb

Please sign in to comment.