From 513389d99539a9667dbcae55ac67cae538109e0b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 6 May 2022 16:27:17 -0700 Subject: [PATCH] Receive schema from flight data. --- .../integration_test.rs | 47 ++++++++++++------- 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/integration-testing/src/flight_client_scenarios/integration_test.rs b/integration-testing/src/flight_client_scenarios/integration_test.rs index 7b32c571383b..4632683cc5fc 100644 --- a/integration-testing/src/flight_client_scenarios/integration_test.rs +++ b/integration-testing/src/flight_client_scenarios/integration_test.rs @@ -31,6 +31,7 @@ use arrow_flight::{ use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt}; use tonic::{Request, Streaming}; +use arrow::datatypes::Schema; use std::sync::Arc; type Error = Box; @@ -60,7 +61,7 @@ pub async fn run_scenario(host: &str, port: u16, path: &str) -> Result { batches.clone(), ) .await?; - verify_data(client, descriptor, schema, &batches).await?; + verify_data(client, descriptor, &batches).await?; Ok(()) } @@ -143,7 +144,6 @@ async fn send_batch( async fn verify_data( mut client: Client, descriptor: FlightDescriptor, - expected_schema: SchemaRef, expected_data: &[RecordBatch], ) -> Result { let resp = client.get_flight_info(Request::new(descriptor)).await?; @@ -163,13 +163,7 @@ async fn verify_data( "No locations returned from Flight server", ); for location in endpoint.location { - consume_flight_location( - location, - ticket.clone(), - expected_data, - expected_schema.clone(), - ) - .await?; + consume_flight_location(location, ticket.clone(), expected_data).await?; } } @@ -180,7 +174,6 @@ async fn consume_flight_location( location: Location, ticket: Ticket, expected_data: &[RecordBatch], - schema: SchemaRef, ) -> Result { let mut location = location; // The other Flight implementations use the `grpc+tcp` scheme, but the Rust http libs @@ -192,16 +185,17 @@ async fn consume_flight_location( let resp = 
client.do_get(ticket).await?; let mut resp = resp.into_inner(); - // We already have the schema from the FlightInfo, but the server sends it again as the - // first FlightData. Ignore this one. - let _schema_again = resp.next().await.unwrap(); + let flight_schema = receive_schema_flight_data(&mut resp) + .await + .unwrap_or_else(|| panic!("Failed to receive flight schema")); + let actual_schema = Arc::new(flight_schema); - let mut dictionaries_by_field = vec![None; schema.fields().len()]; + let mut dictionaries_by_field = vec![None; actual_schema.fields().len()]; for (counter, expected_batch) in expected_data.iter().enumerate() { let data = receive_batch_flight_data( &mut resp, - schema.clone(), + actual_schema.clone(), &mut dictionaries_by_field, ) .await @@ -216,9 +210,12 @@ async fn consume_flight_location( let metadata = counter.to_string().into_bytes(); assert_eq!(metadata, data.app_metadata); - let actual_batch = - flight_data_to_arrow_batch(&data, schema.clone(), &dictionaries_by_field) - .expect("Unable to convert flight data to Arrow batch"); + let actual_batch = flight_data_to_arrow_batch( + &data, + actual_schema.clone(), + &dictionaries_by_field, + ) + .expect("Unable to convert flight data to Arrow batch"); assert_eq!(expected_batch.schema(), actual_batch.schema()); assert_eq!(expected_batch.num_columns(), actual_batch.num_columns()); @@ -244,6 +241,20 @@ async fn consume_flight_location( Ok(()) } +async fn receive_schema_flight_data(resp: &mut Streaming<FlightData>) -> Option<Schema> { + let data = resp.next().await?.ok()?; + let message = arrow::ipc::root_as_message(&data.data_header[..]) + .expect("Error parsing message"); + + // message header is a Schema, so read it + let ipc_schema: ipc::Schema = message + .header_as_schema() + .expect("Unable to read IPC message as schema"); + let schema = ipc::convert::fb_to_schema(ipc_schema); + + Some(schema) +} + async fn receive_batch_flight_data( resp: &mut Streaming<FlightData>, schema: SchemaRef,