Fix generate_nested_dictionary_case integration test failure for Rust cases (#1636)

* Fix ipc nested dict

* Rename dictionaries_by_field.

* Fix a few more inconsistent names

* Rename a few more
viirya authored May 8, 2022
1 parent e0a527b commit 6ad893c
Showing 5 changed files with 76 additions and 69 deletions.
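For context, the substance of the change is that dictionary values are no longer tracked positionally in a `Vec<Option<ArrayRef>>` indexed by top-level field position, but in a `HashMap<i64, ArrayRef>` keyed by dictionary id, so dictionary-encoded fields nested inside structs and lists can still find their values. The standalone Rust sketch below illustrates that id-based lookup; it uses simplified stand-in types (`DictValues` and a minimal `Field`) rather than arrow's `ArrayRef` and `Field`, so treat it as an illustration of the idea, not the crate's API.

```rust
use std::collections::HashMap;

// Simplified stand-ins for arrow's ArrayRef and Field, for illustration only.
#[derive(Clone, Debug)]
struct DictValues(Vec<String>);

struct Field {
    name: String,
    dict_id: Option<i64>,
}

// Dictionaries are keyed by the dictionary id carried on the field itself,
// so a dictionary-encoded field resolves at any nesting depth; a positional
// Vec<Option<...>> indexed by top-level field number cannot reach a child
// field of a Struct or List.
fn lookup_dictionary<'a>(
    field: &Field,
    dictionaries_by_id: &'a HashMap<i64, DictValues>,
) -> Result<&'a DictValues, String> {
    let dict_id = field
        .dict_id
        .ok_or_else(|| format!("Field {} does not have dict id", field.name))?;
    dictionaries_by_id
        .get(&dict_id)
        .ok_or_else(|| format!("Cannot find a dictionary batch with dict id: {}", dict_id))
}

fn main() {
    let mut dictionaries_by_id = HashMap::new();
    // Reading a dictionary batch registers its values under its dictionary id.
    dictionaries_by_id.insert(2, DictValues(vec!["a".into(), "b".into()]));

    // A dictionary-encoded child nested inside a struct still resolves,
    // because only the dict id matters, not the field's position.
    let nested_child = Field {
        name: "struct_col.child".into(),
        dict_id: Some(2),
    };
    println!("{:?}", lookup_dictionary(&nested_child, &dictionaries_by_id));
}
```

In the actual commit, `read_dictionary` performs the insert into the map and `create_array` performs the lookup via `field.dict_id()`, as shown in the diff below.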
5 changes: 3 additions & 2 deletions arrow-flight/src/utils.rs
@@ -18,6 +18,7 @@
//! Utilities to assist with reading and writing Arrow data as Flight messages
use crate::{FlightData, IpcMessage, SchemaAsIpc, SchemaResult};
use std::collections::HashMap;

use arrow::array::ArrayRef;
use arrow::datatypes::{Schema, SchemaRef};
@@ -49,7 +50,7 @@ pub fn flight_data_from_arrow_batch(
pub fn flight_data_to_arrow_batch(
data: &FlightData,
schema: SchemaRef,
dictionaries_by_field: &[Option<ArrayRef>],
dictionaries_by_id: &HashMap<i64, ArrayRef>,
) -> Result<RecordBatch> {
// check that the data_header is a record batch message
let message = arrow::ipc::root_as_message(&data.data_header[..]).map_err(|err| {
@@ -68,7 +69,7 @@ pub fn flight_data_to_arrow_batch(
&data.data_body,
batch,
schema,
dictionaries_by_field,
dictionaries_by_id,
None,
)
})?
86 changes: 43 additions & 43 deletions arrow/src/ipc/reader.rs
@@ -54,14 +54,15 @@ fn read_buffer(buf: &ipc::Buffer, a_data: &[u8]) -> Buffer {
/// - cast the 64-bit array to the appropriate data type
fn create_array(
nodes: &[ipc::FieldNode],
data_type: &DataType,
field: &Field,
data: &[u8],
buffers: &[ipc::Buffer],
dictionaries: &[Option<ArrayRef>],
dictionaries_by_id: &HashMap<i64, ArrayRef>,
mut node_index: usize,
mut buffer_index: usize,
) -> Result<(ArrayRef, usize, usize)> {
use DataType::*;
let data_type = field.data_type();
let array = match data_type {
Utf8 | Binary | LargeBinary | LargeUtf8 => {
let array = create_primitive_array(
@@ -99,10 +100,10 @@ fn create_array(
buffer_index += 2;
let triple = create_array(
nodes,
list_field.data_type(),
list_field,
data,
buffers,
dictionaries,
dictionaries_by_id,
node_index,
buffer_index,
)?;
@@ -121,10 +122,10 @@ fn create_array(
buffer_index += 1;
let triple = create_array(
nodes,
list_field.data_type(),
list_field,
data,
buffers,
dictionaries,
dictionaries_by_id,
node_index,
buffer_index,
)?;
@@ -146,10 +147,10 @@ fn create_array(
for struct_field in struct_fields {
let triple = create_array(
nodes,
struct_field.data_type(),
struct_field,
data,
buffers,
dictionaries,
dictionaries_by_id,
node_index,
buffer_index,
)?;
@@ -173,15 +174,25 @@ fn create_array(
.iter()
.map(|buf| read_buffer(buf, data))
.collect();
let value_array = dictionaries[node_index].clone().unwrap();

let dict_id = field.dict_id().ok_or_else(|| {
ArrowError::IoError(format!("Field {} does not have dict id", field))
})?;

let value_array = dictionaries_by_id.get(&dict_id).ok_or_else(|| {
ArrowError::IoError(format!(
"Cannot find a dictionary batch with dict id: {}",
dict_id
))
})?;
node_index += 1;
buffer_index += 2;

create_dictionary_array(
index_node,
data_type,
&index_buffers[..],
value_array,
value_array.clone(),
)
}
Union(fields, mode) => {
@@ -209,10 +220,10 @@ fn create_array(
for field in fields {
let triple = create_array(
nodes,
field.data_type(),
field,
data,
buffers,
dictionaries,
dictionaries_by_id,
node_index,
buffer_index,
)?;
@@ -457,7 +468,7 @@ pub fn read_record_batch(
buf: &[u8],
batch: ipc::RecordBatch,
schema: SchemaRef,
dictionaries: &[Option<ArrayRef>],
dictionaries_by_id: &HashMap<i64, ArrayRef>,
projection: Option<&[usize]>,
) -> Result<RecordBatch> {
let buffers = batch.buffers().ok_or_else(|| {
@@ -477,10 +488,10 @@
let field = &fields[index];
let triple = create_array(
field_nodes,
field.data_type(),
field,
buf,
buffers,
dictionaries,
dictionaries_by_id,
node_index,
buffer_index,
)?;
@@ -495,10 +506,10 @@
for field in schema.fields() {
let triple = create_array(
field_nodes,
field.data_type(),
field,
buf,
buffers,
dictionaries,
dictionaries_by_id,
node_index,
buffer_index,
)?;
Expand All @@ -511,12 +522,12 @@ pub fn read_record_batch(
}

/// Read the dictionary from the buffer and provided metadata,
/// updating the `dictionaries_by_field` with the resulting dictionary
/// updating the `dictionaries_by_id` with the resulting dictionary
pub fn read_dictionary(
buf: &[u8],
batch: ipc::DictionaryBatch,
schema: &Schema,
dictionaries_by_field: &mut [Option<ArrayRef>],
dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
) -> Result<()> {
if batch.isDelta() {
return Err(ArrowError::IoError(
@@ -545,7 +556,7 @@ pub fn read_dictionary(
buf,
batch.data().unwrap(),
Arc::new(schema),
dictionaries_by_field,
dictionaries_by_id,
None,
)?;
Some(record_batch.column(0).clone())
@@ -556,16 +567,10 @@
ArrowError::InvalidArgumentError("dictionary id not found in schema".to_string())
})?;

// for all fields with this dictionary id, update the dictionaries vector
// in the reader. Note that a dictionary batch may be shared between many fields.
// We don't currently record the isOrdered field. This could be general
// attributes of arrays.
for (i, field) in schema.all_fields().iter().enumerate() {
if field.dict_id() == Some(id) {
// Add (possibly multiple) array refs to the dictionaries array.
dictionaries_by_field[i] = Some(dictionary_values.clone());
}
}
// Add (possibly multiple) array refs to the dictionaries array.
dictionaries_by_id.insert(id, dictionary_values.clone());

Ok(())
}
@@ -592,7 +597,7 @@ pub struct FileReader<R: Read + Seek> {
/// Optional dictionaries for each schema field.
///
/// Dictionaries may be appended to in the streaming format.
dictionaries_by_field: Vec<Option<ArrayRef>>,
dictionaries_by_id: HashMap<i64, ArrayRef>,

/// Metadata version
metadata_version: ipc::MetadataVersion,
@@ -650,7 +655,7 @@ impl<R: Read + Seek> FileReader<R> {
let schema = ipc::convert::fb_to_schema(ipc_schema);

// Create an array of optional dictionary value arrays, one per field.
let mut dictionaries_by_field = vec![None; schema.all_fields().len()];
let mut dictionaries_by_id = HashMap::new();
if let Some(dictionaries) = footer.dictionaries() {
for block in dictionaries {
// read length from end of offset
@@ -683,12 +688,7 @@
))?;
reader.read_exact(&mut buf)?;

read_dictionary(
&buf,
batch,
&schema,
&mut dictionaries_by_field,
)?;
read_dictionary(&buf, batch, &schema, &mut dictionaries_by_id)?;
}
t => {
return Err(ArrowError::IoError(format!(
@@ -713,7 +713,7 @@ impl<R: Read + Seek> FileReader<R> {
blocks: blocks.to_vec(),
current_block: 0,
total_blocks,
dictionaries_by_field,
dictionaries_by_id,
metadata_version: footer.version(),
projection,
})
@@ -795,7 +795,7 @@ impl<R: Read + Seek> FileReader<R> {
&buf,
batch,
self.schema(),
&self.dictionaries_by_field,
&self.dictionaries_by_id,
self.projection.as_ref().map(|x| x.0.as_ref()),

).map(Some)
@@ -840,7 +840,7 @@ pub struct StreamReader<R: Read> {
/// Optional dictionaries for each schema field.
///
/// Dictionaries may be appended to in the streaming format.
dictionaries_by_field: Vec<Option<ArrayRef>>,
dictionaries_by_id: HashMap<i64, ArrayRef>,

/// An indicator of whether the stream is complete.
///
@@ -884,7 +884,7 @@ impl<R: Read> StreamReader<R> {
let schema = ipc::convert::fb_to_schema(ipc_schema);

// Create an array of optional dictionary value arrays, one per field.
let dictionaries_by_field = vec![None; schema.all_fields().len()];
let dictionaries_by_id = HashMap::new();

let projection = match projection {
Some(projection_indices) => {
@@ -897,7 +897,7 @@
reader,
schema: Arc::new(schema),
finished: false,
dictionaries_by_field,
dictionaries_by_id,
projection,
})
}
@@ -971,7 +971,7 @@ impl<R: Read> StreamReader<R> {
let mut buf = vec![0; message.bodyLength() as usize];
self.reader.read_exact(&mut buf)?;

read_record_batch(&buf, batch, self.schema(), &self.dictionaries_by_field, self.projection.as_ref().map(|x| x.0.as_ref())).map(Some)
read_record_batch(&buf, batch, self.schema(), &self.dictionaries_by_id, self.projection.as_ref().map(|x| x.0.as_ref())).map(Some)
}
ipc::MessageHeader::DictionaryBatch => {
let batch = message.header_as_dictionary_batch().ok_or_else(|| {
@@ -984,7 +984,7 @@
self.reader.read_exact(&mut buf)?;

read_dictionary(
&buf, batch, &self.schema, &mut self.dictionaries_by_field
&buf, batch, &self.schema, &mut self.dictionaries_by_id
)?;

// read the next message until we encounter a RecordBatch
24 changes: 11 additions & 13 deletions integration-testing/src/flight_client_scenarios/integration_test.rs
@@ -16,6 +16,7 @@
// under the License.

use crate::{read_json_file, ArrowFile};
use std::collections::HashMap;

use arrow::{
array::ArrayRef,
@@ -196,28 +197,25 @@ async fn consume_flight_location(
// first FlightData. Ignore this one.
let _schema_again = resp.next().await.unwrap();

let mut dictionaries_by_field = vec![None; schema.fields().len()];
let mut dictionaries_by_id = HashMap::new();

for (counter, expected_batch) in expected_data.iter().enumerate() {
let data = receive_batch_flight_data(
&mut resp,
schema.clone(),
&mut dictionaries_by_field,
)
.await
.unwrap_or_else(|| {
panic!(
let data =
receive_batch_flight_data(&mut resp, schema.clone(), &mut dictionaries_by_id)
.await
.unwrap_or_else(|| {
panic!(
"Got fewer batches than expected, received so far: {} expected: {}",
counter,
expected_data.len(),
)
});
});

let metadata = counter.to_string().into_bytes();
assert_eq!(metadata, data.app_metadata);

let actual_batch =
flight_data_to_arrow_batch(&data, schema.clone(), &dictionaries_by_field)
flight_data_to_arrow_batch(&data, schema.clone(), &dictionaries_by_id)
.expect("Unable to convert flight data to Arrow batch");

assert_eq!(expected_batch.schema(), actual_batch.schema());
@@ -247,7 +245,7 @@ async fn consume_flight_location(
async fn receive_batch_flight_data(
resp: &mut Streaming<FlightData>,
schema: SchemaRef,
dictionaries_by_field: &mut [Option<ArrayRef>],
dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
) -> Option<FlightData> {
let mut data = resp.next().await?.ok()?;
let mut message = arrow::ipc::root_as_message(&data.data_header[..])
@@ -260,7 +258,7 @@ async fn receive_batch_flight_data(
.header_as_dictionary_batch()
.expect("Error parsing dictionary"),
&schema,
dictionaries_by_field,
dictionaries_by_id,
)
.expect("Error reading dictionary");

