Skip to content

Commit

Permalink
Add DictionaryArray::try_new() to create dictionaries from pre exis…
Browse files Browse the repository at this point in the history
…ting arrays (#1300)

* Add DictionaryArray::try_new()

* Update arrow/src/array/array_dictionary.rs

Co-authored-by: Liang-Chi Hsieh <[email protected]>

Co-authored-by: Liang-Chi Hsieh <[email protected]>
  • Loading branch information
alamb and viirya authored Feb 15, 2022
1 parent 1ab95d5 commit 7d46ac1
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 113 deletions.
91 changes: 83 additions & 8 deletions arrow/src/array/array_dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use super::{
};
use crate::datatypes::ArrowNativeType;
use crate::datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType, DataType};
use crate::error::Result;

/// A dictionary array where each element is a single value indexed by an integer key.
/// This is mostly used to represent strings or a limited set of primitive types as integers,
Expand All @@ -50,15 +51,31 @@ use crate::datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType, DataType};
/// let array : DictionaryArray<Int8Type> = test.into_iter().collect();
/// assert_eq!(array.keys(), &Int8Array::from(vec![0, 0, 1, 2]));
/// ```
///
/// Example from existing arrays:
///
/// ```
/// use arrow::array::{DictionaryArray, Int8Array, StringArray};
/// use arrow::datatypes::Int8Type;
/// // You can form your own DictionaryArray by providing the
/// // values (dictionary) and keys (indexes into the dictionary):
/// let values = StringArray::from_iter_values(["a", "b", "c"]);
/// let keys = Int8Array::from_iter_values([0, 0, 1, 2]);
/// let array = DictionaryArray::<Int8Type>::try_new(&keys, &values).unwrap();
/// let expected: DictionaryArray::<Int8Type> = vec!["a", "a", "b", "c"]
/// .into_iter()
/// .collect();
/// assert_eq!(&array, &expected);
/// ```
pub struct DictionaryArray<K: ArrowPrimitiveType> {
/// Data of this dictionary. Note that this is _not_ compatible with the C Data interface,
/// as, in the current implementation, `values` below are the first child of this struct.
data: ArrayData,

/// The keys of this dictionary. These are constructed from the buffer and null bitmap
/// of `data`.
/// Also, note that these do not correspond to the true values of this array. Rather, they map
/// to the real values.
/// The keys of this dictionary. These are constructed from the
/// buffer and null bitmap of `data`. Also, note that these do
/// not correspond to the true values of this array. Rather, they
/// map to the real values.
keys: PrimitiveArray<K>,

/// Array of dictionary values (can be any DataType).
Expand All @@ -69,6 +86,27 @@ pub struct DictionaryArray<K: ArrowPrimitiveType> {
}

impl<'a, K: ArrowPrimitiveType> DictionaryArray<K> {
/// Attempt to create a new DictionaryArray with a specified keys
/// (indexes into the dictionary) and values (dictionary)
/// array. Returns an error if there are any keys that are outside
/// of the dictionary array.
pub fn try_new(keys: &PrimitiveArray<K>, values: &dyn Array) -> Result<Self> {
let dict_data_type = DataType::Dictionary(
Box::new(keys.data_type().clone()),
Box::new(values.data_type().clone()),
);

// Note: This does more work than necessary by rebuilding /
// revalidating all the data
let data = ArrayData::builder(dict_data_type)
.len(keys.len())
.add_buffer(keys.data().buffers()[0].clone())
.add_child_data(values.data().clone())
.build()?;

Ok(data.into())
}

/// Return an array view of the keys of this dictionary as a PrimitiveArray.
pub fn keys(&self) -> &PrimitiveArray<K> {
&self.keys
Expand Down Expand Up @@ -256,14 +294,14 @@ impl<T: ArrowPrimitiveType> fmt::Debug for DictionaryArray<T> {
mod tests {
use super::*;

use crate::{
array::Int16Array,
datatypes::{Int32Type, Int8Type, UInt32Type, UInt8Type},
};
use crate::{
array::Int16DictionaryArray, array::PrimitiveDictionaryBuilder,
datatypes::DataType,
};
use crate::{
array::{Int16Array, Int32Array},
datatypes::{Int32Type, Int8Type, UInt32Type, UInt8Type},
};
use crate::{buffer::Buffer, datatypes::ToByteSlice};

#[test]
Expand Down Expand Up @@ -422,4 +460,41 @@ mod tests {
.validate_full()
.expect("All null array has valid array data");
}

#[test]
fn test_try_new() {
    // Three dictionary entries, referenced by four keys (one of them null).
    let dictionary: StringArray = [Some("foo"), Some("bar"), Some("baz")]
        .into_iter()
        .collect();
    let indexes: Int32Array = [Some(0), Some(2), None, Some(1)].into_iter().collect();

    let dict_array =
        DictionaryArray::<Int32Type>::try_new(&indexes, &dictionary).unwrap();

    // Key and value types must be taken from the input arrays.
    assert_eq!(dict_array.keys().data_type(), &DataType::Int32);
    assert_eq!(dict_array.values().data_type(), &DataType::Utf8);

    // Note that the `None` key shows up as `0` in the expected rendering.
    let rendered = format!("{:?}", dict_array);
    assert_eq!(
        "DictionaryArray {keys: PrimitiveArray<Int32>\n[\n 0,\n 2,\n 0,\n 1,\n] values: StringArray\n[\n \"foo\",\n \"bar\",\n \"baz\",\n]}\n",
        rendered
    );
}

#[test]
#[should_panic(expected = "Value at position 1 out of bounds: 3 (should be in [0, 1])")]
fn test_try_new_index_too_large() {
    // The dictionary holds only two entries, so valid keys are 0 and 1.
    let dictionary: StringArray = [Some("foo"), Some("bar")].into_iter().collect();
    // Key 3 (at position 1) points past the end of the dictionary.
    let indexes: Int32Array = [Some(0), Some(3)].into_iter().collect();

    // Unwrapping the validation error produces the expected panic.
    let _ = DictionaryArray::<Int32Type>::try_new(&indexes, &dictionary).unwrap();
}

#[test]
#[should_panic(expected = "Value at position 0 out of bounds: -100 (should be in [0, 1])")]
fn test_try_new_index_too_small() {
    // The dictionary holds only two entries, so valid keys are 0 and 1.
    let dictionary: StringArray = [Some("foo"), Some("bar")].into_iter().collect();
    // A negative key can never be a valid index into the dictionary.
    let indexes: Int32Array = [Some(-100)].into_iter().collect();

    // Unwrapping the validation error produces the expected panic.
    let _ = DictionaryArray::<Int32Type>::try_new(&indexes, &dictionary).unwrap();
}
}
8 changes: 7 additions & 1 deletion arrow/src/array/equal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
//! depend on dynamic casting of `Array`.
use super::{
Array, ArrayData, BinaryOffsetSizeTrait, BooleanArray, DecimalArray,
Array, ArrayData, BinaryOffsetSizeTrait, BooleanArray, DecimalArray, DictionaryArray,
FixedSizeBinaryArray, FixedSizeListArray, GenericBinaryArray, GenericListArray,
GenericStringArray, MapArray, NullArray, OffsetSizeTrait, PrimitiveArray,
StringOffsetSizeTrait, StructArray,
Expand Down Expand Up @@ -81,6 +81,12 @@ impl<T: ArrowPrimitiveType> PartialEq for PrimitiveArray<T> {
}
}

/// Equality for dictionary arrays: two arrays are equal when `equal`
/// reports that their underlying `ArrayData` are equal.
impl<K: ArrowPrimitiveType> PartialEq for DictionaryArray<K> {
    fn eq(&self, other: &Self) -> bool {
        let lhs = self.data();
        let rhs = other.data();
        equal(lhs, rhs)
    }
}

impl PartialEq for BooleanArray {
fn eq(&self, other: &BooleanArray) -> bool {
equal(self.data(), other.data())
Expand Down
148 changes: 44 additions & 104 deletions arrow/src/compute/kernels/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2699,7 +2699,6 @@ mod tests {

use super::*;
use crate::datatypes::Int8Type;
use crate::datatypes::ToByteSlice;
use crate::{array::Int32Array, array::Int64Array, datatypes::Field};

/// Evaluate `KERNEL` with two vectors as inputs and assert against the expected output.
Expand Down Expand Up @@ -4664,41 +4663,15 @@ mod tests {
);
}

/// Test helper that wraps a raw `keys` buffer and a dictionary
/// (`value_data`) into `ArrayData` of type
/// `DataType::Dictionary(key_type, <type of value_data>)`.
///
/// The array length is fixed at 3, regardless of how many keys the
/// buffer actually holds. Panics if the builder rejects the data.
fn get_dict_arraydata(
    keys: Buffer,
    key_type: DataType,
    value_data: ArrayData,
) -> ArrayData {
    // Read the value type before `value_data` is moved into the builder.
    let dictionary_type = DataType::Dictionary(
        Box::new(key_type),
        Box::new(value_data.data_type().clone()),
    );

    let builder = ArrayData::builder(dictionary_type)
        .len(3)
        .add_buffer(keys)
        .add_child_data(value_data);

    builder.build().unwrap()
}

#[test]
fn test_eq_dyn_dictionary_i8_array() {
let key_type = DataType::Int8;
// Construct a value array
let value_data = ArrayData::builder(DataType::Int8)
.len(8)
.add_buffer(Buffer::from(
&[10_i8, 11, 12, 13, 14, 15, 16, 17].to_byte_slice(),
))
.build()
.unwrap();
let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);

let keys1 = Buffer::from(&[2_i8, 3, 4].to_byte_slice());
let keys2 = Buffer::from(&[2_i8, 4, 4].to_byte_slice());
let dict_array1: DictionaryArray<Int8Type> = Int8DictionaryArray::from(
get_dict_arraydata(keys1, key_type.clone(), value_data.clone()),
);
let dict_array2: DictionaryArray<Int8Type> =
Int8DictionaryArray::from(get_dict_arraydata(keys2, key_type, value_data));
let keys1 = Int8Array::from_iter_values([2_i8, 3, 4]);
let keys2 = Int8Array::from_iter_values([2_i8, 4, 4]);
let dict_array1 = DictionaryArray::try_new(&keys1, &values).unwrap();
let dict_array2 = DictionaryArray::try_new(&keys2, &values).unwrap();

let result = eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
Expand All @@ -4707,23 +4680,14 @@ mod tests {

#[test]
fn test_eq_dyn_dictionary_u64_array() {
let key_type = DataType::UInt64;
// Construct a value array
let value_data = ArrayData::builder(DataType::UInt64)
.len(8)
.add_buffer(Buffer::from(
&[10_u64, 11, 12, 13, 14, 15, 16, 17].to_byte_slice(),
))
.build()
.unwrap();
let values = UInt64Array::from_iter_values([10_u64, 11, 12, 13, 14, 15, 16, 17]);

let keys1 = Buffer::from(&[1_u64, 3, 4].to_byte_slice());
let keys2 = Buffer::from(&[2_u64, 3, 5].to_byte_slice());
let dict_array1: DictionaryArray<UInt64Type> = UInt64DictionaryArray::from(
get_dict_arraydata(keys1, key_type.clone(), value_data.clone()),
);
let dict_array2: DictionaryArray<UInt64Type> =
UInt64DictionaryArray::from(get_dict_arraydata(keys2, key_type, value_data));
let keys1 = UInt64Array::from_iter_values([1_u64, 3, 4]);
let keys2 = UInt64Array::from_iter_values([2_u64, 3, 5]);
let dict_array1 =
DictionaryArray::<UInt64Type>::try_new(&keys1, &values).unwrap();
let dict_array2 =
DictionaryArray::<UInt64Type>::try_new(&keys2, &values).unwrap();

let result = eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
Expand Down Expand Up @@ -4757,29 +4721,17 @@ mod tests {

#[test]
fn test_eq_dyn_dictionary_binary_array() {
let key_type = DataType::UInt64;

// Construct a value array
let values: [u8; 12] = [
b'h', b'e', b'l', b'l', b'o', b'p', b'a', b'r', b'q', b'u', b'e', b't',
];
let offsets: [i32; 4] = [0, 5, 5, 12];

// Array data: ["hello", "", "parquet"]
let value_data = ArrayData::builder(DataType::Binary)
.len(3)
.add_buffer(Buffer::from_slice_ref(&offsets))
.add_buffer(Buffer::from_slice_ref(&values))
.build()
.unwrap();
let values: BinaryArray = ["hello", "", "parquet"]
.into_iter()
.map(|b| Some(b.as_bytes()))
.collect();

let keys1 = Buffer::from(&[0_u64, 1, 2].to_byte_slice());
let keys2 = Buffer::from(&[0_u64, 2, 1].to_byte_slice());
let dict_array1: DictionaryArray<UInt64Type> = UInt64DictionaryArray::from(
get_dict_arraydata(keys1, key_type.clone(), value_data.clone()),
);
let dict_array2: DictionaryArray<UInt64Type> =
UInt64DictionaryArray::from(get_dict_arraydata(keys2, key_type, value_data));
let keys1 = UInt64Array::from_iter_values([0_u64, 1, 2]);
let keys2 = UInt64Array::from_iter_values([0_u64, 2, 1]);
let dict_array1 =
DictionaryArray::<UInt64Type>::try_new(&keys1, &values).unwrap();
let dict_array2 =
DictionaryArray::<UInt64Type>::try_new(&keys2, &values).unwrap();

let result = eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
Expand All @@ -4791,18 +4743,14 @@ mod tests {

#[test]
fn test_eq_dyn_dictionary_interval_array() {
let key_type = DataType::UInt64;
let values = IntervalDayTimeArray::from(vec![1, 6, 10, 2, 3, 5]);

let value_array = IntervalDayTimeArray::from(vec![1, 6, 10, 2, 3, 5]);
let value_data = value_array.data().clone();

let keys1 = Buffer::from(&[1_u64, 0, 3].to_byte_slice());
let keys2 = Buffer::from(&[2_u64, 0, 3].to_byte_slice());
let dict_array1: DictionaryArray<UInt64Type> = UInt64DictionaryArray::from(
get_dict_arraydata(keys1, key_type.clone(), value_data.clone()),
);
let dict_array2: DictionaryArray<UInt64Type> =
UInt64DictionaryArray::from(get_dict_arraydata(keys2, key_type, value_data));
let keys1 = UInt64Array::from_iter_values([1_u64, 0, 3]);
let keys2 = UInt64Array::from_iter_values([2_u64, 0, 3]);
let dict_array1 =
DictionaryArray::<UInt64Type>::try_new(&keys1, &values).unwrap();
let dict_array2 =
DictionaryArray::<UInt64Type>::try_new(&keys2, &values).unwrap();

let result = eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
Expand All @@ -4811,18 +4759,14 @@ mod tests {

#[test]
fn test_eq_dyn_dictionary_date_array() {
let key_type = DataType::UInt64;

let value_array = Date32Array::from(vec![1, 6, 10, 2, 3, 5]);
let value_data = value_array.data().clone();
let values = Date32Array::from(vec![1, 6, 10, 2, 3, 5]);

let keys1 = Buffer::from(&[1_u64, 0, 3].to_byte_slice());
let keys2 = Buffer::from(&[2_u64, 0, 3].to_byte_slice());
let dict_array1: DictionaryArray<UInt64Type> = UInt64DictionaryArray::from(
get_dict_arraydata(keys1, key_type.clone(), value_data.clone()),
);
let dict_array2: DictionaryArray<UInt64Type> =
UInt64DictionaryArray::from(get_dict_arraydata(keys2, key_type, value_data));
let keys1 = UInt64Array::from_iter_values([1_u64, 0, 3]);
let keys2 = UInt64Array::from_iter_values([2_u64, 0, 3]);
let dict_array1 =
DictionaryArray::<UInt64Type>::try_new(&keys1, &values).unwrap();
let dict_array2 =
DictionaryArray::<UInt64Type>::try_new(&keys2, &values).unwrap();

let result = eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
Expand All @@ -4831,18 +4775,14 @@ mod tests {

#[test]
fn test_eq_dyn_dictionary_bool_array() {
let key_type = DataType::UInt64;

let value_array = BooleanArray::from(vec![true, false]);
let value_data = value_array.data().clone();

let keys1 = Buffer::from(&[1_u64, 1, 1].to_byte_slice());
let keys2 = Buffer::from(&[0_u64, 1, 0].to_byte_slice());
let dict_array1: DictionaryArray<UInt64Type> = UInt64DictionaryArray::from(
get_dict_arraydata(keys1, key_type.clone(), value_data.clone()),
);
let dict_array2: DictionaryArray<UInt64Type> =
UInt64DictionaryArray::from(get_dict_arraydata(keys2, key_type, value_data));
let values = BooleanArray::from(vec![true, false]);

let keys1 = UInt64Array::from_iter_values([1_u64, 1, 1]);
let keys2 = UInt64Array::from_iter_values([0_u64, 1, 0]);
let dict_array1 =
DictionaryArray::<UInt64Type>::try_new(&keys1, &values).unwrap();
let dict_array2 =
DictionaryArray::<UInt64Type>::try_new(&keys2, &values).unwrap();

let result = eq_dyn(&dict_array1, &dict_array2);
assert!(result.is_ok());
Expand Down

0 comments on commit 7d46ac1

Please sign in to comment.