From 862bb4ae085b08ac8a29e8f8314932631ce27c08 Mon Sep 17 00:00:00 2001
From: kamille <caoruiqiu.crq@antgroup.com>
Date: Sat, 5 Oct 2024 18:44:34 +0800
Subject: [PATCH] fix `equal_to` in `PrimitiveGroupValueBuilder` (#12758)

* fix `equal_to` in `PrimitiveGroupValueBuilder`.

* fix typo.

* add unit tests.

* reduce calls to `is_null`.
---
 .../aggregates/group_values/group_column.rs   | 107 ++++++++++++++++--
 1 file changed, 98 insertions(+), 9 deletions(-)

diff --git a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs
index 15c93262968e..aa246ac95b8b 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs
@@ -91,15 +91,28 @@ impl<T: ArrowPrimitiveType, const NULLABLE: bool> GroupColumn
     for PrimitiveGroupValueBuilder<T, NULLABLE>
 {
     fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool {
-        // Perf: skip null check (by short circuit) if input is not ullable
-        let null_match = if NULLABLE {
-            self.nulls.is_null(lhs_row) == array.is_null(rhs_row)
-        } else {
-            true
-        };
+        // Perf: skip the null check entirely if the input is not nullable
+        if NULLABLE {
+            // In the nullable path, first check whether the `exist row` and the
+            // `input row` are both null or both non-null
+            let is_exist_null = self.nulls.is_null(lhs_row);
+            let null_match = is_exist_null == array.is_null(rhs_row);
+            if !null_match {
+                // Exactly one of `exist row` / `input row` is null: not equal
+                return false;
+            } else if is_exist_null {
+                // Both `exist row` and `input row` are null: they are equal,
+                // regardless of whatever values are stored in their slots
+                //
+                // NOTICE: values stored under null slots must not be compared; they are
+                // meaningless and are not guaranteed to be the same
+                //
+                return true;
+            }
+            // Otherwise both rows are non-null, so compare their values
+        }
 
-        null_match
-            && self.group_values[lhs_row] == array.as_primitive::<T>().value(rhs_row)
+        self.group_values[lhs_row] == array.as_primitive::<T>().value(rhs_row)
     }
 
     fn append_val(&mut self, array: &ArrayRef, row: usize) {
@@ -373,9 +386,13 @@ where
 mod tests {
     use std::sync::Arc;
 
-    use arrow_array::{ArrayRef, StringArray};
+    use arrow::datatypes::Int64Type;
+    use arrow_array::{ArrayRef, Int64Array, StringArray};
+    use arrow_buffer::{BooleanBufferBuilder, NullBuffer};
     use datafusion_physical_expr::binary_map::OutputType;
 
+    use crate::aggregates::group_values::group_column::PrimitiveGroupValueBuilder;
+
     use super::{ByteGroupValueBuilder, GroupColumn};
 
     #[test]
@@ -422,4 +439,76 @@ mod tests {
         ])) as ArrayRef;
         assert_eq!(&output, &array);
     }
+
+    #[test]
+    fn test_nullable_primitive_equal_to() {
+        // Covers the following cases (one per row index, in this order):
+        //   - exist null, input not null
+        //   - exist null, input null; values not equal
+        //   - exist null, input null; values equal
+        //   - exist not null, input null
+        //   - exist not null, input not null; values not equal
+        //   - exist not null, input not null; values equal
+
+        // Define a nullable PrimitiveGroupValueBuilder and append all six rows
+        let mut builder = PrimitiveGroupValueBuilder::<Int64Type, true>::new();
+        let builder_array = Arc::new(Int64Array::from(vec![
+            None,
+            None,
+            None,
+            Some(1),
+            Some(2),
+            Some(3),
+        ])) as ArrayRef;
+        builder.append_val(&builder_array, 0);
+        builder.append_val(&builder_array, 1);
+        builder.append_val(&builder_array, 2);
+        builder.append_val(&builder_array, 3);
+        builder.append_val(&builder_array, 4);
+        builder.append_val(&builder_array, 5);
+
+        // Define the input array; its null buffer is built by hand below (rows 1..=3 null)
+        let (_, values, _) =
+            Int64Array::from(vec![Some(1), Some(2), None, None, Some(1), Some(3)])
+                .into_parts();
+
+        let mut boolean_buffer_builder = BooleanBufferBuilder::new(6);
+        boolean_buffer_builder.append(true);
+        boolean_buffer_builder.append(false);
+        boolean_buffer_builder.append(false);
+        boolean_buffer_builder.append(false);
+        boolean_buffer_builder.append(true);
+        boolean_buffer_builder.append(true);
+        let nulls = NullBuffer::new(boolean_buffer_builder.finish());
+        let input_array = Arc::new(Int64Array::new(values, Some(nulls))) as ArrayRef;
+
+        // Check each case listed above, row by row
+        assert!(!builder.equal_to(0, &input_array, 0));
+        assert!(builder.equal_to(1, &input_array, 1));
+        assert!(builder.equal_to(2, &input_array, 2));
+        assert!(!builder.equal_to(3, &input_array, 3));
+        assert!(!builder.equal_to(4, &input_array, 4));
+        assert!(builder.equal_to(5, &input_array, 5));
+    }
+
+    #[test]
+    fn test_not_nullable_primitive_equal_to() {
+        // Covers the following cases (one per row index, in this order):
+        //   - values equal
+        //   - values not equal
+
+        // Define a non-nullable PrimitiveGroupValueBuilder and append both rows
+        let mut builder = PrimitiveGroupValueBuilder::<Int64Type, false>::new();
+        let builder_array =
+            Arc::new(Int64Array::from(vec![Some(0), Some(1)])) as ArrayRef;
+        builder.append_val(&builder_array, 0);
+        builder.append_val(&builder_array, 1);
+
+        // Define the input array: row 0 matches the builder, row 1 differs
+        let input_array = Arc::new(Int64Array::from(vec![Some(0), Some(2)])) as ArrayRef;
+
+        // Check each case listed above, row by row
+        assert!(builder.equal_to(0, &input_array, 0));
+        assert!(!builder.equal_to(1, &input_array, 1));
+    }
 }