Skip to content

Commit

Permalink
Implement comparisons on nested data types such that distinct/except …
Browse files Browse the repository at this point in the history
…would work (#11117)

This relies on newer functionality in arrow 52 and allows
DataFrame.except() to properly work on schemas with structs and lists

Closes #10749
  • Loading branch information
rtyler authored Jun 27, 2024
1 parent f58df32 commit d2ff218
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 1 deletion.
76 changes: 76 additions & 0 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3445,6 +3445,82 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_except_nested_struct() -> Result<()> {
use arrow::array::StructArray;

let nested_schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, true),
Field::new("lat", DataType::Int32, true),
Field::new("long", DataType::Int32, true),
]));
let schema = Arc::new(Schema::new(vec![
Field::new("value", DataType::Int32, true),
Field::new(
"nested",
DataType::Struct(nested_schema.fields.clone()),
true,
),
]));
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])),
Arc::new(StructArray::from(vec![
(
Arc::new(Field::new("id", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
(
Arc::new(Field::new("lat", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
(
Arc::new(Field::new("long", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
])),
],
)
.unwrap();

let updated_batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int32Array::from(vec![Some(1), Some(12), Some(3)])),
Arc::new(StructArray::from(vec![
(
Arc::new(Field::new("id", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
(
Arc::new(Field::new("lat", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
(
Arc::new(Field::new("long", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
),
])),
],
)
.unwrap();

let ctx = SessionContext::new();
let before = ctx.read_batch(batch).expect("Failed to make DataFrame");
let after = ctx
.read_batch(updated_batch)
.expect("Failed to make DataFrame");

let diff = before
.except(after)
.expect("Failed to except")
.collect()
.await?;
assert_eq!(diff.len(), 1);
Ok(())
}

#[tokio::test]
async fn nested_explain_should_fail() -> Result<()> {
let ctx = SessionContext::new();
Expand Down
14 changes: 13 additions & 1 deletion datafusion/physical-plan/src/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,15 @@ use arrow::array::{
Array, ArrayRef, BooleanArray, BooleanBufferBuilder, PrimitiveArray, UInt32Array,
UInt64Array,
};
use arrow::buffer::NullBuffer;
use arrow::compute::kernels::cmp::{eq, not_distinct};
use arrow::compute::{and, concat_batches, take, FilterBuilder};
use arrow::datatypes::{Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use arrow::util::bit_util;
use arrow_array::cast::downcast_array;
use arrow_schema::ArrowError;
use arrow_ord::ord::make_comparator;
use arrow_schema::{ArrowError, SortOptions};
use datafusion_common::utils::memory::estimate_memory_size;
use datafusion_common::{
internal_datafusion_err, internal_err, plan_err, project_schema, DataFusionError,
Expand Down Expand Up @@ -1210,6 +1212,16 @@ fn eq_dyn_null(
right: &dyn Array,
null_equals_null: bool,
) -> Result<BooleanArray, ArrowError> {
// Nested datatypes cannot use the underlying not_distinct function and must use a special
// implementation
// <https://github.com/apache/datafusion/issues/10749>
if left.data_type().is_nested() && null_equals_null {
let cmp = make_comparator(left, right, SortOptions::default())?;
let len = left.len().min(right.len());
let values = (0..len).map(|i| cmp(i, i).is_eq()).collect();
let nulls = NullBuffer::union(left.nulls(), right.nulls());
return Ok(BooleanArray::new(values, nulls));
}
match (left.data_type(), right.data_type()) {
_ if null_equals_null => not_distinct(&left, &right),
_ => eq(&left, &right),
Expand Down

0 comments on commit d2ff218

Please sign in to comment.