diff --git a/arrow-ord/src/rank.rs b/arrow-ord/src/rank.rs index b4d2f3a9b7c0..c617b2a05c26 100644 --- a/arrow-ord/src/rank.rs +++ b/arrow-ord/src/rank.rs @@ -82,26 +82,6 @@ fn primitive_rank( rank_impl(values.len(), to_sort, options, T::compare, T::is_eq) } -#[inline(never)] -fn boolean_rank(array: &BooleanArray, options: SortOptions) -> Vec { - let len: u32 = array.len().try_into().unwrap(); - - let to_sort: Vec<(bool, u32)> = match array.nulls().filter(|n| n.null_count() > 0) { - Some(n) => n - .valid_indices() - .map(|idx| (array.value(idx), idx as u32)) - .collect(), - None => array.values().iter().zip(0..len).collect(), - }; - rank_impl( - array.len(), - to_sort, - options, - |a: bool, b: bool| a.cmp(&b), - |a: bool, b: bool| a == b, - ) -} - #[inline(never)] fn bytes_rank(array: &GenericByteArray, options: SortOptions) -> Vec { let to_sort: Vec<(&[u8], u32)> = match array.nulls().filter(|n| n.null_count() > 0) { @@ -162,6 +142,48 @@ where out } +/// Return the index for the rank when ranking boolean array +/// +/// The index is calculated as follows: +/// if is_null is true, the index is 2 +/// if is_null is false and the value is true, the index is 1 +/// otherwise, the index is 0 +/// +/// false is 0 and true is 1 because these are the value when cast to number +#[inline] +fn get_boolean_rank_index(value: bool, is_null: bool) -> usize { + let is_null_num = is_null as usize; + (is_null_num << 1) | (value as usize & !is_null_num) +} + +#[inline(never)] +fn boolean_rank(array: &BooleanArray, options: SortOptions) -> Vec { + let ranks_index: [u32; 3] = match (options.descending, options.nulls_first) { + // The order is null, true, false + (true, true) => [2, 1, 0], + // The order is true, false, null + (true, false) => [1, 0, 2], + // The order is null, false, true + (false, true) => [1, 2, 0], + // The order is false, true, null + (false, false) => [0, 1, 2], + }; + + match array.nulls().filter(|n| n.null_count() > 0) { + Some(n) => array + .values() + .iter() + .zip(n.iter()) + .map(|(value, is_valid)| ranks_index[get_boolean_rank_index(value, !is_valid)]) + .collect::>(), + None => array + .values() + .iter() + .map(|value| ranks_index[value as usize]) + .collect::>(), + } +} + #[cfg(test)] mod tests { use super::*; @@ -204,6 +226,14 @@ mod tests { assert_eq!(res, &[4, 6, 3, 6, 3, 3]); } + #[test] + fn test_get_boolean_rank_index() { + assert_eq!(get_boolean_rank_index(true, true), 2); + assert_eq!(get_boolean_rank_index(false, true), 2); + assert_eq!(get_boolean_rank_index(true, false), 1); + assert_eq!(get_boolean_rank_index(false, false), 0); + } + #[test] fn test_booleans() { let descending = SortOptions { @@ -223,22 +253,22 @@ mod tests { let a = BooleanArray::from(vec![Some(true), Some(true), None, Some(false), Some(false)]); let res = rank(&a, None).unwrap(); - assert_eq!(res, &[5, 5, 1, 3, 3]); + assert_eq!(res, &[2, 2, 0, 1, 1]); let res = rank(&a, Some(descending)).unwrap(); - assert_eq!(res, &[3, 3, 1, 5, 5]); + assert_eq!(res, &[1, 1, 0, 2, 2]); let res = rank(&a, Some(nulls_last)).unwrap(); - assert_eq!(res, &[4, 4, 5, 2, 2]); + assert_eq!(res, &[1, 1, 2, 0, 0]); let res = rank(&a, Some(nulls_last_descending)).unwrap(); - assert_eq!(res, &[2, 2, 5, 4, 4]); + assert_eq!(res, &[0, 0, 2, 1, 1]); // Test with non-zero null values let nulls = NullBuffer::from(vec![true, true, false, true, true]); let a = BooleanArray::new(vec![true, true, true, false, false].into(), Some(nulls)); let res = rank(&a, None).unwrap(); - assert_eq!(res, &[5, 5, 1, 3, 3]); + assert_eq!(res, &[2, 2, 0, 1, 1]); } #[test]