From fdb8c68aa697942a0d93b41e9bdf6011af48eab4 Mon Sep 17 00:00:00 2001 From: Max Drach Date: Mon, 13 Dec 2021 23:00:18 -0800 Subject: [PATCH] Support parquet read from dictionary-encoded nonoptional pages (#683) --- src/io/parquet/read/fixed_size_binary.rs | 33 ++++++++++++++++++++++++ tests/it/io/parquet/read.rs | 15 +++++++++++ 2 files changed, 48 insertions(+) diff --git a/src/io/parquet/read/fixed_size_binary.rs b/src/io/parquet/read/fixed_size_binary.rs index ddd6e02beb9..01973e8db7e 100644 --- a/src/io/parquet/read/fixed_size_binary.rs +++ b/src/io/parquet/read/fixed_size_binary.rs @@ -70,6 +70,31 @@ pub(crate) fn read_dict_buffer( } } +/// Assumptions: No rep levels +pub(crate) fn read_dict_required( + indices_buffer: &[u8], + additional: usize, + size: usize, + dict: &FixedLenByteArrayPageDict, + values: &mut MutableBuffer, + validity: &mut MutableBitmap, +) { + let dict_values = dict.values(); + + // SPEC: Data page format: the bit width used to encode the entry ids stored as 1 byte (max bit width = 32), + // SPEC: followed by the values encoded using RLE/Bit packed described above (with the given bit width). + let bit_width = indices_buffer[0]; + let indices_buffer = &indices_buffer[1..]; + + let indices = hybrid_rle::HybridRleDecoder::new(indices_buffer, bit_width as u32, additional); + + for index in indices { + let index = index as usize; + values.extend_from_slice(&dict_values[index * size..(index + 1) * size]); + } + validity.extend_constant(additional * size, true); +} + pub(crate) fn read_optional( validity_buffer: &[u8], values_buffer: &[u8], @@ -217,6 +242,14 @@ pub(crate) fn extend_from_page( values, validity, ), + (Encoding::PlainDictionary, Some(dict), false) => read_dict_required( + values_buffer, + additional, + size, + dict.as_any().downcast_ref().unwrap(), + values, + validity, + ), (Encoding::Plain, _, true) => read_optional( validity_buffer, values_buffer, diff --git a/tests/it/io/parquet/read.rs b/tests/it/io/parquet/read.rs index 98dddd1968c..04a2e1a9cef 100644 --- a/tests/it/io/parquet/read.rs +++ b/tests/it/io/parquet/read.rs @@ -313,6 +313,11 @@ fn v2_decimal_9_required() -> Result<()> { test_pyarrow_integration(6, 2, "basic", false, true, None) } +#[test] +fn v2_decimal_9_required_dict() -> Result<()> { + test_pyarrow_integration(6, 2, "basic", true, true, None) +} + #[test] fn v2_decimal_18_nullable() -> Result<()> { test_pyarrow_integration(8, 2, "basic", false, false, None) @@ -323,6 +328,11 @@ fn v2_decimal_18_required() -> Result<()> { test_pyarrow_integration(7, 2, "basic", false, true, None) } +#[test] +fn v2_decimal_18_required_dict() -> Result<()> { + test_pyarrow_integration(7, 2, "basic", true, true, None) +} + #[test] fn v2_decimal_26_nullable() -> Result<()> { test_pyarrow_integration(9, 2, "basic", false, false, None) @@ -333,6 +343,11 @@ fn v2_decimal_26_required() -> Result<()> { test_pyarrow_integration(8, 2, "basic", false, true, None) } +#[test] +fn v2_decimal_26_required_dict() -> Result<()> { + test_pyarrow_integration(8, 2, "basic", true, true, None) +} + #[test] fn v1_struct_optional() -> Result<()> { test_pyarrow_integration(0, 1, "struct", false, false, None)