From d2ff2189dfb8b4624ae2c08846cd713871b37d8c Mon Sep 17 00:00:00 2001 From: "R. Tyler Croy" Date: Thu, 27 Jun 2024 14:39:14 -0700 Subject: [PATCH] Implement comparisons on nested data types such that distinct/except would work (#11117) This relies on newer functionality in arrow 52 and allows DataFrame.except() to properly work on schemas with structs and lists Closes #10749 --- datafusion/core/src/dataframe/mod.rs | 76 +++++++++++++++++++ .../physical-plan/src/joins/hash_join.rs | 14 +++- 2 files changed, 89 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index e1fc8273e6ff..86e510969b33 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -3445,6 +3445,82 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_except_nested_struct() -> Result<()> { + use arrow::array::StructArray; + + let nested_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("lat", DataType::Int32, true), + Field::new("long", DataType::Int32, true), + ])); + let schema = Arc::new(Schema::new(vec![ + Field::new("value", DataType::Int32, true), + Field::new( + "nested", + DataType::Struct(nested_schema.fields.clone()), + true, + ), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])), + Arc::new(StructArray::from(vec![ + ( + Arc::new(Field::new("id", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + ), + ( + Arc::new(Field::new("lat", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + ), + ( + Arc::new(Field::new("long", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + ), + ])), + ], + ) + .unwrap(); + + let updated_batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![Some(1), Some(12), Some(3)])), + Arc::new(StructArray::from(vec![ + ( + Arc::new(Field::new("id", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + ), + ( + Arc::new(Field::new("lat", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + ), + ( + Arc::new(Field::new("long", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + ), + ])), + ], + ) + .unwrap(); + + let ctx = SessionContext::new(); + let before = ctx.read_batch(batch).expect("Failed to make DataFrame"); + let after = ctx + .read_batch(updated_batch) + .expect("Failed to make DataFrame"); + + let diff = before + .except(after) + .expect("Failed to except") + .collect() + .await?; + assert_eq!(diff.len(), 1); + Ok(()) + } + #[tokio::test] async fn nested_explain_should_fail() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 5353092d5c45..7d268839df12 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -52,13 +52,15 @@ use arrow::array::{ Array, ArrayRef, BooleanArray, BooleanBufferBuilder, PrimitiveArray, UInt32Array, UInt64Array, }; +use arrow::buffer::NullBuffer; use arrow::compute::kernels::cmp::{eq, not_distinct}; use arrow::compute::{and, concat_batches, take, FilterBuilder}; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow::util::bit_util; use arrow_array::cast::downcast_array; -use arrow_schema::ArrowError; +use arrow_ord::ord::make_comparator; +use arrow_schema::{ArrowError, SortOptions}; use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ internal_datafusion_err, internal_err, plan_err, project_schema, DataFusionError, @@ -1210,6 +1212,16 @@ fn eq_dyn_null( right: &dyn Array, null_equals_null: bool, ) -> Result { + // Nested datatypes cannot use the underlying not_distinct function and must use a special + // implementation + // + if left.data_type().is_nested() && null_equals_null { + let cmp = make_comparator(left, right, SortOptions::default())?; + let len = left.len().min(right.len()); + let values = (0..len).map(|i| cmp(i, i).is_eq()).collect(); + let nulls = NullBuffer::union(left.nulls(), right.nulls()); + return Ok(BooleanArray::new(values, nulls)); + } match (left.data_type(), right.data_type()) { _ if null_equals_null => not_distinct(&left, &right), _ => eq(&left, &right),