Skip to content

Commit

Permalink
fix distinct for struct.
Browse files Browse the repository at this point in the history
  • Loading branch information
my-vegetable-has-exploded committed Dec 27, 2023
1 parent 1b19ec5 commit 4f8522d
Showing 1 changed file with 50 additions and 53 deletions.
103 changes: 50 additions & 53 deletions arrow-ord/src/cmp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,17 +201,15 @@ fn compare_op_nulls(
len: usize,
) -> Result<Option<NullBuffer>, ArrowError> {
use arrow_schema::DataType::*;
if matches!(op, Op::Distinct | Op::NotDistinct) {
// for [not]Distinct, the result is never null
return Ok(None);
}

let l_t = l.data_type();
let r_t = r.data_type();
let l_nulls = l.logical_nulls().filter(|n| n.null_count() > 0);
let r_nulls = r.logical_nulls().filter(|n| n.null_count() > 0);
// for [not]Distinct, the result is never null
match op {
Op::Distinct | Op::NotDistinct => {
return Ok(None);
}
_ => {}
}
let nulls = match (l_nulls, l_s, r_nulls, r_s) {
// Either both sides are scalar or neither side is scalar
(Some(l_nulls), true, Some(r_nulls), true)
Expand Down Expand Up @@ -267,21 +265,7 @@ fn compare_op_values(
let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r);
let r_t = r.data_type();

if l_t.is_nested() {
if !l_t.equals_datatype(r_t) {
return Err(ArrowError::InvalidArgumentError(format!(
"Invalid comparison operation: {l_t} {op} {r_t}"
)));
}
match (l_t, op) {
(Struct(_), Op::Equal | Op::NotEqual | Op::Distinct | Op::NotDistinct) => {}
_ => {
return Err(ArrowError::InvalidArgumentError(format!(
"Invalid comparison operation: {l_t} {op} {r_t}"
)));
}
}
} else if r_t != l_t {
if !l_t.equals_datatype(r_t) {
return Err(ArrowError::InvalidArgumentError(format!(
"Invalid comparison operation: {l_t} {op} {r_t}"
)));
Expand Down Expand Up @@ -380,12 +364,22 @@ fn compare_op_struct_values(
r_s: bool,
len: usize,
) -> Result<BooleanBuffer, ArrowError> {
// Each field is compared with the positive operator (Equal or NotDistinct):
// if any field is not equal (is distinct), the result is false for Equal (NotDistinct).
// NotEqual and Distinct are the negations of Equal and NotDistinct, so `neg` records
// whether the combined result must be reversed.
let neg = match op {
Op::Equal | Op::NotDistinct => false,
Op::NotEqual | Op::Distinct => true,
_ => unreachable!(),
_ => {
return Err(ArrowError::InvalidArgumentError(format!(
"Invalid comparison operation: Struct {op} Struct"
)))
}
};

let op = match op {
Op::NotEqual => Op::Equal,
Op::Distinct => Op::NotDistinct,
_ => op,
};

let l = l.as_struct();
Expand All @@ -396,7 +390,7 @@ fn compare_op_struct_values(
.columns()
.iter()
.zip(r.columns().iter())
.map(|(col_l, col_r)| compare_op_values(Op::Equal, col_l, l_s, col_r, r_s, len))
.map(|(col_l, col_r)| compare_op_values(op, col_l, l_s, col_r, r_s, len))
.collect::<Result<Vec<BooleanBuffer>, ArrowError>>()?;
// combine the result of each field
let equality = child_values
Expand Down Expand Up @@ -852,72 +846,75 @@ mod tests {
assert_eq!(eq(&left, &right).unwrap_err().to_string(), "Invalid argument error: Invalid comparison operation: Struct([Field { name: \"a\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) == Struct([Field { name: \"a\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: \"b\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])");

// test struct('a') < struct('a')
assert_eq!(lt(&left, &left).unwrap_err().to_string(), "Invalid argument error: Invalid comparison operation: Struct([Field { name: \"a\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) < Struct([Field { name: \"a\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])");
assert_eq!(
lt(&left, &left).unwrap_err().to_string(),
"Invalid argument error: Invalid comparison operation: Struct < Struct"
);
}

#[test]
fn test_struct_compare() {
// test struct('a', 'b') against struct('a', 'b'); the null buffer is 0b0111
// left b[2] is different from right b[2]
let left_a = Arc::new(Int32Array::new(
vec![0, 1, 2, 3].into(),
Some(vec![true, false, true, true].into()),
vec![0, 1, 2, 3, 4, 5, 6, 7].into(),
Some(vec![true, false, true, true, false, true, true, false].into()),
));
let right_a = Arc::new(Int32Array::new(
vec![0, 1, 2, 3].into(),
Some(vec![true, false, true, true].into()),
vec![0, 1, 2, 3, 4, 5, 6, 72].into(),
Some(vec![true, false, true, true, false, true, true, false].into()),
));
let left_b = Arc::new(Int32Array::new(
vec![0, 1, 20, 3].into(),
Some(vec![true, true, true, true].into()),
vec![0, 1, 2, 3, 4, 5, 7, 7].into(),
Some(vec![true, true, true, true, true, true, true, true].into()),
));
let right_b = Arc::new(Int32Array::new(
vec![0, 1, 2, 3].into(),
Some(vec![true, true, true, true].into()),
vec![0, 1, 20, 13, 72, 6, 6, 7].into(),
Some(vec![true, true, true, true, true, true, false, true].into()),
));
let field_a = Arc::new(Field::new("a", DataType::Int32, true));
let field_b = Arc::new(Field::new("b", DataType::Int32, true));
// left [{a: 0, b: 0}, {a: NULL, b: 1}, {a: 2, b: 2}, NULL({a: 3, b: 3}), {a: NULL, b: 4}, NULL({a: 5, b: 5}), {a:6, b: 7}, {a: NULL, b: 7}]
let left_struct = StructArray::from((
vec![
(field_a.clone(), left_a.clone() as ArrayRef),
(field_b.clone(), left_b.clone() as ArrayRef),
],
Buffer::from([0b0111]),
Buffer::from([0b11010111]),
));
// right [{a: 0, b: 0}, {a: NULL, b: 1}, {a: 2, b: 20}, NULL({a: 3, b: 13}), {a: NULL, b: 72}, NULL({a: 5, b: 6}), {a: 6, b: NULL}, {a: NULL, b: 7}]
let right_struct = StructArray::from((
vec![
(field_a.clone(), right_a.clone() as ArrayRef),
(field_b.clone(), right_b.clone() as ArrayRef),
],
Buffer::from([0b0111]),
Buffer::from([0b11010111]),
));
let expected = BooleanArray::new(
vec![true, true, false, true].into(),
// a[1] is null in the child, struct[3] is null in the parent
Some(vec![true, false, true, false].into()),
vec![true, true, false, false, false, false, false, false].into(),
Some(vec![true, false, true, false, false, false, false, false].into()),
);
assert_eq!(eq(&left_struct, &right_struct).unwrap(), expected);
assert_eq!(eq(&right_struct, &left_struct).unwrap(), expected);
let expected = BooleanArray::new(
vec![false, false, true, false].into(),
Some(vec![true, false, true, false].into()),
vec![false, false, true, true, true, true, true, true].into(),
Some(vec![true, false, true, false, false, false, false, false].into()),
);
assert_eq!(neq(&left_struct, &right_struct).unwrap(), expected);
assert_eq!(neq(&right_struct, &left_struct).unwrap(), expected);
let expected = BooleanArray::new(
// left[0] equals right[0], left b[1] is not distinct from right b[1], left b[2] is distinct from right b[2], struct[3] is null in the parent
vec![false, false, true, false].into(),
vec![false, false, true, false, true, false, true, false].into(),
None,
);
assert_eq!(distinct(&left_struct, &right_struct).unwrap(), expected);
assert_eq!(distinct(&right_struct, &left_struct).unwrap(), expected);
let expected = BooleanArray::new(vec![true, true, false, true].into(), None);
let expected = BooleanArray::new(
vec![true, true, false, true, false, true, false, true].into(),
None,
);
assert_eq!(not_distinct(&left_struct, &right_struct).unwrap(), expected);
assert_eq!(not_distinct(&right_struct, &left_struct).unwrap(), expected);

let sub_struct_fields = left_struct.fields().clone();

// test struct('a', struct('suba', 'subb')) against struct('a', struct('suba', 'subb')), where the right subb[2] differs from the left subb[2]; the null buffer is 0b0111
let left_struct = StructArray::from((
vec![
(field_a.clone(), left_a.clone() as ArrayRef),
Expand All @@ -930,7 +927,7 @@ mod tests {
Arc::new(left_struct) as ArrayRef,
),
],
Buffer::from([0b0111]),
Buffer::from([0b11010111]),
));
let right_struct = StructArray::from((
vec![
Expand All @@ -944,17 +941,17 @@ mod tests {
Arc::new(right_struct) as ArrayRef,
),
],
Buffer::from([0b0111]),
Buffer::from([0b11010111]),
));
let expected = BooleanArray::new(
vec![true, false, false, true].into(),
Some(vec![true, false, true, false].into()),
vec![true, true, false, false, false, false, false, false].into(),
Some(vec![true, false, true, false, false, false, false, false].into()),
);
assert_eq!(eq(&left_struct, &right_struct).unwrap(), expected);
assert_eq!(eq(&right_struct, &left_struct).unwrap(), expected);
let expected = BooleanArray::new(
vec![false, true, true, false].into(),
Some(vec![true, false, true, false].into()),
vec![false, false, true, true, true, true, true, true].into(),
Some(vec![true, false, true, false, false, false, false, false].into()),
);
assert_eq!(neq(&left_struct, &right_struct).unwrap(), expected);
assert_eq!(neq(&right_struct, &left_struct).unwrap(), expected);
Expand Down

0 comments on commit 4f8522d

Please sign in to comment.