Skip to content

Commit

Permalink
fix: recurse into Map datatype when hydrating dictionaries (#6645)
Browse files Browse the repository at this point in the history
When hydrating dictionaries for the FlightDataEncoder, the Map data type
was assumed not to be nested. This change correctly recurses into the
Map field to hydrate any dictionaries within the map.

Fixes #6644
  • Loading branch information
nathanielc authored Oct 30, 2024
1 parent 56d4713 commit ad56c02
Showing 1 changed file with 161 additions and 0 deletions.
161 changes: 161 additions & 0 deletions arrow-flight/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,15 @@ fn prepare_field_for_flight(
.with_metadata(field.metadata().clone())
}
}
DataType::Map(inner, sorted) => Field::new(
field.name(),
DataType::Map(
prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries).into(),
*sorted,
),
field.is_nullable(),
)
.with_metadata(field.metadata().clone()),
_ => field.as_ref().clone(),
}
}
Expand Down Expand Up @@ -685,6 +694,7 @@ mod tests {
use arrow_cast::pretty::pretty_format_batches;
use arrow_ipc::MetadataVersion;
use arrow_schema::{UnionFields, UnionMode};
use builder::{GenericStringBuilder, MapBuilder};
use std::collections::HashMap;

use super::*;
Expand Down Expand Up @@ -1276,6 +1286,157 @@ mod tests {
verify_flight_round_trip(vec![batch1, batch2, batch3]).await;
}

/// Checks that dictionary-encoded keys and values nested inside a `Map`
/// column are decoded as plain `Utf8` arrays after passing through a
/// default `FlightDataEncoder`.
///
/// Two single-row batches are encoded; every decoded schema and record batch
/// is compared against an equivalent map built without dictionary fields.
#[tokio::test]
async fn test_dictionary_map_hydration() {
    // One map row per batch; the "k2" entry carries a null value.
    // {"k1":"a","k2":null,"k3":"b"}
    let first_row = [("k1", Some("a")), ("k2", None), ("k3", Some("b"))];
    // {"k1":"c","k2":null,"k3":"d"}
    let second_row = [("k1", Some("c")), ("k2", None), ("k3", Some("d"))];

    // Input side: keys and values are both dictionary encoded.
    let mut builder = MapBuilder::new(
        None,
        StringDictionaryBuilder::<UInt16Type>::new(),
        StringDictionaryBuilder::<UInt16Type>::new(),
    );

    for (key, value) in first_row {
        builder.keys().append_value(key);
        if let Some(value) = value {
            builder.values().append_value(value);
        } else {
            builder.values().append_null();
        }
    }
    builder.append(true).unwrap();

    let arr1 = builder.finish();

    for (key, value) in second_row {
        builder.keys().append_value(key);
        if let Some(value) = value {
            builder.values().append_value(value);
        } else {
            builder.values().append_null();
        }
    }
    builder.append(true).unwrap();

    let arr2 = builder.finish();

    let schema = Arc::new(Schema::new(vec![Field::new_map(
        "dict_map",
        "entries",
        Field::new_dictionary("keys", DataType::UInt16, DataType::Utf8, false),
        Field::new_dictionary("values", DataType::UInt16, DataType::Utf8, true),
        false,
        false,
    )]));

    let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
    let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

    let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);

    let encoder = FlightDataEncoderBuilder::default().build(stream);

    let mut decoder = FlightDataDecoder::new(encoder);

    // Expected side: same map layout, but with plain Utf8 keys and values.
    let expected_schema = Schema::new(vec![Field::new_map(
        "dict_map",
        "entries",
        Field::new("keys", DataType::Utf8, false),
        Field::new("values", DataType::Utf8, true),
        false,
        false,
    )]);

    let expected_schema = Arc::new(expected_schema);

    // Builder without dictionary fields
    let mut builder = MapBuilder::new(
        None,
        GenericStringBuilder::<i32>::new(),
        GenericStringBuilder::<i32>::new(),
    );

    for (key, value) in first_row {
        builder.keys().append_value(key);
        if let Some(value) = value {
            builder.values().append_value(value);
        } else {
            builder.values().append_null();
        }
    }
    builder.append(true).unwrap();

    let arr1 = builder.finish();

    for (key, value) in second_row {
        builder.keys().append_value(key);
        if let Some(value) = value {
            builder.values().append_value(value);
        } else {
            builder.values().append_null();
        }
    }
    builder.append(true).unwrap();

    let arr2 = builder.finish();

    let mut expected_arrays = vec![arr1, arr2].into_iter();

    while let Some(decoded) = decoder.next().await {
        let decoded = decoded.unwrap();
        match decoded.payload {
            DecodedPayload::None => {}
            DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
            DecodedPayload::RecordBatch(b) => {
                assert_eq!(b.schema(), expected_schema);
                let expected_array = expected_arrays
                    .next()
                    .expect("decoder produced more record batches than were encoded");
                let map_array =
                    downcast_array::<MapArray>(b.column_by_name("dict_map").unwrap());

                assert_eq!(map_array, expected_array);
            }
        }
    }

    // Without this check the test would pass vacuously if the decoder
    // yielded fewer record batches than were sent (or none at all).
    assert!(
        expected_arrays.next().is_none(),
        "decoder produced fewer record batches than were encoded"
    );
}

#[tokio::test]
async fn test_dictionary_map_resend() {
let mut builder = MapBuilder::new(
None,
StringDictionaryBuilder::<UInt16Type>::new(),
StringDictionaryBuilder::<UInt16Type>::new(),
);

// {"k1":"a","k2":null,"k3":"b"}
builder.keys().append_value("k1");
builder.values().append_value("a");
builder.keys().append_value("k2");
builder.values().append_null();
builder.keys().append_value("k3");
builder.values().append_value("b");
builder.append(true).unwrap();

let arr1 = builder.finish();

// {"k1":"c","k2":null,"k3":"d"}
builder.keys().append_value("k1");
builder.values().append_value("c");
builder.keys().append_value("k2");
builder.values().append_null();
builder.keys().append_value("k3");
builder.values().append_value("d");
builder.append(true).unwrap();

let arr2 = builder.finish();

let schema = Arc::new(Schema::new(vec![Field::new_map(
"dict_map",
"entries",
Field::new_dictionary("keys", DataType::UInt16, DataType::Utf8, false),
Field::new_dictionary("values", DataType::UInt16, DataType::Utf8, true),
false,
false,
)]));

let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();

verify_flight_round_trip(vec![batch1, batch2]).await;
}

async fn verify_flight_round_trip(mut batches: Vec<RecordBatch>) {
let expected_schema = batches.first().unwrap().schema();

Expand Down

0 comments on commit ad56c02

Please sign in to comment.