Skip to content

Commit

Permalink
Add Round trip tests for Array <--> ScalarValue (apache#13777)
Browse files Browse the repository at this point in the history
* Add Round trip tests for Array <--> ScalarValue

* String dictionary test

* remove unnecessary value

* Improve comments
  • Loading branch information
alamb authored Dec 16, 2024
1 parent 668984e commit 59410ea
Showing 1 changed file with 191 additions and 0 deletions.
191 changes: 191 additions & 0 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3892,6 +3892,7 @@ mod tests {
use arrow::compute::{is_null, kernels};
use arrow::error::ArrowError;
use arrow::util::pretty::pretty_format_columns;
use arrow_array::types::Float64Type;
use arrow_buffer::{Buffer, NullBuffer};
use arrow_schema::Fields;
use chrono::NaiveDate;
Expand Down Expand Up @@ -5554,6 +5555,196 @@ mod tests {
assert_eq!(&array, &expected);
}

#[test]
fn round_trip() {
    // Every array type listed here must survive a round trip through a
    // `ScalarValue`. Arrays that need a builder are assembled first so the
    // case table at the bottom stays flat.

    // string dictionary: ["foo", null, "bar"]
    let string_dictionary: ArrayRef = {
        let mut dict_builder = StringDictionaryBuilder::<Int32Type>::new();
        dict_builder.append("foo").unwrap();
        dict_builder.append_null();
        dict_builder.append("bar").unwrap();
        Arc::new(dict_builder.finish())
    };

    // dense union
    /* Dense union fails due to https://github.com/apache/datafusion/issues/13762
    {
    let mut builder = UnionBuilder::new_dense();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append::<Float64Type>("b", 3.4).unwrap();
    Arc::new(builder.build().unwrap())
    }
    */

    // sparse union with one Int32 entry and one Float64 entry
    let sparse_union: ArrayRef = {
        let mut union_builder = UnionBuilder::new_sparse();
        union_builder.append::<Int32Type>("a", 1).unwrap();
        union_builder.append::<Float64Type>("b", 3.4).unwrap();
        Arc::new(union_builder.build().unwrap())
    };

    // list array: [[A, B], [], null]
    let list: ArrayRef = {
        let mut list_builder = ListBuilder::new(StringBuilder::new());
        // [A, B]
        list_builder.values().append_value("A");
        list_builder.values().append_value("B");
        list_builder.append(true);
        // [ ] (empty list)
        list_builder.append(true);
        // Null entry; the child value appended behind it is irrelevant
        list_builder.values().append_value("?");
        list_builder.append(false);
        Arc::new(list_builder.finish())
    };

    // large list array: [[A, B], [], null]
    let large_list: ArrayRef = {
        let mut list_builder = LargeListBuilder::new(StringBuilder::new());
        // [A, B]
        list_builder.values().append_value("A");
        list_builder.values().append_value("B");
        list_builder.append(true);
        // [ ] (empty list)
        list_builder.append(true);
        // Null
        list_builder.append(false);
        Arc::new(list_builder.finish())
    };

    // fixed size list array: [[0, 1, 2], null, [3, null, 5]]
    let fixed_size_list: ArrayRef = {
        let mut list_builder = FixedSizeListBuilder::new(Int32Builder::new(), 3);
        // [0, 1, 2]
        list_builder.values().append_value(0);
        list_builder.values().append_value(1);
        list_builder.values().append_value(2);
        list_builder.append(true);
        // null (three null child slots are appended for this entry)
        list_builder.values().append_null();
        list_builder.values().append_null();
        list_builder.values().append_null();
        list_builder.append(false);
        // [3, null, 5]
        list_builder.values().append_value(3);
        list_builder.values().append_null();
        list_builder.values().append_value(5);
        list_builder.append(true);
        Arc::new(list_builder.finish())
    };

    // map: [{"joe": 1}, {}, null]
    let map: ArrayRef = {
        let mut map_builder =
            MapBuilder::new(None, StringBuilder::new(), Int32Builder::with_capacity(4));
        // {"joe": 1}
        map_builder.keys().append_value("joe");
        map_builder.values().append_value(1);
        map_builder.append(true).unwrap();
        // {}
        map_builder.append(true).unwrap();
        // null
        map_builder.append(false).unwrap();
        Arc::new(map_builder.finish())
    };

    let arrays: Vec<ArrayRef> = vec![
        // int
        Arc::new(Int8Array::from(vec![Some(1), None, Some(3)])),
        Arc::new(Int16Array::from(vec![Some(1), None, Some(3)])),
        Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])),
        Arc::new(Int64Array::from(vec![Some(1), None, Some(3)])),
        Arc::new(UInt8Array::from(vec![Some(1), None, Some(3)])),
        Arc::new(UInt16Array::from(vec![Some(1), None, Some(3)])),
        Arc::new(UInt32Array::from(vec![Some(1), None, Some(3)])),
        Arc::new(UInt64Array::from(vec![Some(1), None, Some(3)])),
        // bool
        Arc::new(BooleanArray::from(vec![Some(true), None, Some(false)])),
        // float
        Arc::new(Float32Array::from(vec![Some(1.0), None, Some(3.0)])),
        Arc::new(Float64Array::from(vec![Some(1.0), None, Some(3.0)])),
        // string array
        Arc::new(StringArray::from(vec![Some("foo"), None, Some("bar")])),
        Arc::new(LargeStringArray::from(vec![Some("foo"), None, Some("bar")])),
        Arc::new(StringViewArray::from(vec![Some("foo"), None, Some("bar")])),
        // string dictionary
        string_dictionary,
        // binary array
        Arc::new(BinaryArray::from_iter(vec![
            Some(b"foo"),
            None,
            Some(b"bar"),
        ])),
        Arc::new(LargeBinaryArray::from_iter(vec![
            Some(b"foo"),
            None,
            Some(b"bar"),
        ])),
        Arc::new(BinaryViewArray::from_iter(vec![
            Some(b"foo"),
            None,
            Some(b"bar"),
        ])),
        // timestamp
        Arc::new(TimestampSecondArray::from(vec![Some(1), None, Some(3)])),
        Arc::new(TimestampMillisecondArray::from(vec![
            Some(1),
            None,
            Some(3),
        ])),
        Arc::new(TimestampMicrosecondArray::from(vec![
            Some(1),
            None,
            Some(3),
        ])),
        Arc::new(TimestampNanosecondArray::from(vec![Some(1), None, Some(3)])),
        // timestamp with timezone
        Arc::new(
            TimestampSecondArray::from(vec![Some(1), None, Some(3)])
                .with_timezone_opt(Some("UTC")),
        ),
        Arc::new(
            TimestampMillisecondArray::from(vec![Some(1), None, Some(3)])
                .with_timezone_opt(Some("UTC")),
        ),
        Arc::new(
            TimestampMicrosecondArray::from(vec![Some(1), None, Some(3)])
                .with_timezone_opt(Some("UTC")),
        ),
        Arc::new(
            TimestampNanosecondArray::from(vec![Some(1), None, Some(3)])
                .with_timezone_opt(Some("UTC")),
        ),
        // date
        Arc::new(Date32Array::from(vec![Some(1), None, Some(3)])),
        Arc::new(Date64Array::from(vec![Some(1), None, Some(3)])),
        // time
        Arc::new(Time32SecondArray::from(vec![Some(1), None, Some(3)])),
        Arc::new(Time32MillisecondArray::from(vec![Some(1), None, Some(3)])),
        Arc::new(Time64MicrosecondArray::from(vec![Some(1), None, Some(3)])),
        Arc::new(Time64NanosecondArray::from(vec![Some(1), None, Some(3)])),
        // null array
        Arc::new(NullArray::new(3)),
        // sparse union
        sparse_union,
        // nested types
        list,
        large_list,
        fixed_size_list,
        map,
    ];

    arrays.into_iter().for_each(round_trip_through_scalar);
}

/// Asserts that every row of `arr` survives an Array -> ScalarValue -> Array
/// round trip.
///
/// Row by row:
/// 1. convert the row of `arr` into a `ScalarValue`
/// 2. convert that `ScalarValue` back into a one-element `ArrayRef`
/// 3. check the length, the data type, and equality with the matching
///    one-row slice of `arr`
///
/// Panics if either conversion fails or any of the checks does not hold.
fn round_trip_through_scalar(arr: ArrayRef) {
    (0..arr.len()).for_each(|row| {
        // Array --> Scalar
        let scalar = ScalarValue::try_from_array(&arr, row).unwrap();
        // Scalar --> Array
        let rebuilt = scalar.to_array_of_size(1).unwrap();

        assert_eq!(rebuilt.len(), 1);
        assert_eq!(rebuilt.data_type(), arr.data_type());

        // Compare against the single-row window of the original array
        let expected = arr.slice(row, 1);
        assert_eq!(rebuilt.as_ref(), expected.as_ref());
    });
}

#[test]
fn test_scalar_union_sparse() {
let field_a = Arc::new(Field::new("A", DataType::Int32, true));
Expand Down

0 comments on commit 59410ea

Please sign in to comment.