Skip to content

Commit

Permalink
fix equal_to in PrimitiveGroupValueBuilder (apache#12758)
Browse files Browse the repository at this point in the history
* fix `equal_to` in `PrimitiveGroupValueBuilder`.

* fix typo.

* add unit tests (UTs).

* reduce the number of calls to `is_null`.
  • Loading branch information
Rachelint authored Oct 5, 2024
1 parent 030c4e9 commit 862bb4a
Showing 1 changed file with 98 additions and 9 deletions.
107 changes: 98 additions & 9 deletions datafusion/physical-plan/src/aggregates/group_values/group_column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,28 @@ impl<T: ArrowPrimitiveType, const NULLABLE: bool> GroupColumn
for PrimitiveGroupValueBuilder<T, NULLABLE>
{
/// Returns `true` when the group value stored at `lhs_row` equals the value
/// at `rhs_row` of `array`.
///
/// With `NULLABLE == true` the comparison is null-aware:
/// - exactly one side null  -> not equal
/// - both sides null        -> equal (the underlying values are ignored)
/// - both sides non-null    -> compare the underlying values
///
/// With `NULLABLE == false` no null check is performed at all and only the
/// underlying values are compared.
///
/// NOTE(review): `array` is downcast with `as_primitive::<T>()`, so it is
/// expected to be a primitive array of type `T` — confirm callers uphold this.
fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool {
    // Perf: `NULLABLE` is a const generic, so the null handling below is
    // skipped entirely when the input is not nullable.
    if NULLABLE {
        // Look up the stored row's nullness once and reuse it for both checks.
        let is_exist_null = self.nulls.is_null(lhs_row);

        // Nullness differs between the existing row and the input row:
        // they can never be equal.
        if is_exist_null != array.is_null(rhs_row) {
            return false;
        }

        // Both rows are null: treat them as equal.
        //
        // NOTICE: the values behind null slots must not be compared; they
        // are meaningless and not guaranteed to be the same on both sides.
        if is_exist_null {
            return true;
        }

        // Otherwise both rows are non-null, fall through to the value check.
    }

    self.group_values[lhs_row] == array.as_primitive::<T>().value(rhs_row)
}

fn append_val(&mut self, array: &ArrayRef, row: usize) {
Expand Down Expand Up @@ -373,9 +386,13 @@ where
mod tests {
use std::sync::Arc;

use arrow_array::{ArrayRef, StringArray};
use arrow::datatypes::Int64Type;
use arrow_array::{ArrayRef, Int64Array, StringArray};
use arrow_buffer::{BooleanBufferBuilder, NullBuffer};
use datafusion_physical_expr::binary_map::OutputType;

use crate::aggregates::group_values::group_column::PrimitiveGroupValueBuilder;

use super::{ByteGroupValueBuilder, GroupColumn};

#[test]
Expand Down Expand Up @@ -422,4 +439,76 @@ mod tests {
])) as ArrayRef;
assert_eq!(&output, &array);
}

#[test]
fn test_nullable_primitive_equal_to() {
    // Row-by-row scenarios exercised against the nullable builder:
    //   row 0: exist null,     input not null
    //   row 1: exist null,     input null; values not equal
    //   row 2: exist null,     input null; values equal
    //   row 3: exist not null, input null
    //   row 4: exist not null, input not null; values not equal
    //   row 5: exist not null, input not null; values equal

    // Populate the builder with three null rows followed by 1, 2, 3.
    let stored = Arc::new(Int64Array::from(vec![
        None,
        None,
        None,
        Some(1),
        Some(2),
        Some(3),
    ])) as ArrayRef;
    let mut builder = PrimitiveGroupValueBuilder::<Int64Type, true>::new();
    for row in 0..6 {
        builder.append_val(&stored, row);
    }

    // Build the input array by hand: take the value buffer of one array and
    // attach a separately constructed validity buffer, so that a slot can be
    // marked null while still carrying a chosen value (e.g. row 1 holds `2`).
    let (_, values, _) =
        Int64Array::from(vec![Some(1), Some(2), None, None, Some(1), Some(3)])
            .into_parts();

    let mut validity = BooleanBufferBuilder::new(6);
    for is_valid in [true, false, false, false, true, true] {
        validity.append(is_valid);
    }
    let nulls = NullBuffer::new(validity.finish());
    let input_array = Arc::new(Int64Array::new(values, Some(nulls))) as ArrayRef;

    // Expected result of `equal_to(row, input, row)` for each scenario above.
    let expected = [false, true, true, false, false, true];
    for (row, &want) in expected.iter().enumerate() {
        assert_eq!(builder.equal_to(row, &input_array, row), want);
    }
}

#[test]
fn test_not_nullable_primitive_equal_to() {
    // Scenarios exercised against the non-nullable builder:
    //   row 0: values equal
    //   row 1: values not equal

    // Builder holding the group values 0 and 1.
    let stored: ArrayRef = Arc::new(Int64Array::from(vec![Some(0), Some(1)]));
    let mut builder = PrimitiveGroupValueBuilder::<Int64Type, false>::new();
    for row in 0..2 {
        builder.append_val(&stored, row);
    }

    // Input rows to compare against: 0 matches row 0, 2 differs from row 1.
    let input_array: ArrayRef = Arc::new(Int64Array::from(vec![Some(0), Some(2)]));

    let first_matches = builder.equal_to(0, &input_array, 0);
    let second_matches = builder.equal_to(1, &input_array, 1);
    assert!(first_matches);
    assert!(!second_matches);
}
}

0 comments on commit 862bb4a

Please sign in to comment.