diff --git a/arrow/benches/partition_kernels.rs b/arrow/benches/partition_kernels.rs index 6a9ce709d33c..ae55fbdad22c 100644 --- a/arrow/benches/partition_kernels.rs +++ b/arrow/benches/partition_kernels.rs @@ -48,7 +48,11 @@ fn bench_partition(sorted_columns: &[ArrayRef]) { }) .collect::>(); - criterion::black_box(lexicographical_partition_ranges(&columns).unwrap()); + criterion::black_box( + lexicographical_partition_ranges(&columns) + .unwrap() + .collect::>(), + ); } fn create_sorted_low_cardinality_data(length: usize) -> Vec { diff --git a/arrow/src/compute/kernels/partition.rs b/arrow/src/compute/kernels/partition.rs index e91f80bb558f..ad35e9239e3f 100644 --- a/arrow/src/compute/kernels/partition.rs +++ b/arrow/src/compute/kernels/partition.rs @@ -21,6 +21,7 @@ use crate::compute::kernels::sort::LexicographicalComparator; use crate::compute::SortColumn; use crate::error::{ArrowError, Result}; use std::cmp::Ordering; +use std::iter::Iterator; use std::ops::Range; /// Given a list of already sorted columns, find partition ranges that would partition @@ -34,65 +35,71 @@ use std::ops::Range; /// range. pub fn lexicographical_partition_ranges( columns: &[SortColumn], -) -> Result>> { - let partition_points = lexicographical_partition_points(columns)?; - Ok(partition_points - .iter() - .zip(partition_points[1..].iter()) - .map(|(&start, &end)| Range { start, end }) - .collect()) +) -> Result> + '_> { + LexicographicalPartitionIterator::try_new(columns) } -/// Given a list of already sorted columns, find partition ranges that would partition -/// lexicographically equal values across columns. -/// -/// Here LexicographicalComparator is used in conjunction with binary -/// search so the columns *MUST* be pre-sorted already. -/// -/// The returned vec would be of size k+1 where k is cardinality of the sorted values; the first and -/// last value would be 0 and n. -fn lexicographical_partition_points(columns: &[SortColumn]) -> Result> { - if columns.is_empty() { - return Err(ArrowError::InvalidArgumentError( - "Sort requires at least one column".to_string(), - )); - } - let row_count = columns[0].values.len(); - if columns.iter().any(|item| item.values.len() != row_count) { - return Err(ArrowError::ComputeError( - "Lexical sort columns have different row counts".to_string(), - )); - }; +struct LexicographicalPartitionIterator<'a> { + comparator: LexicographicalComparator<'a>, + num_rows: usize, + previous_partition_point: usize, + partition_point: usize, + value_indices: Vec, +} + +impl<'a> LexicographicalPartitionIterator<'a> { + fn try_new(columns: &'a [SortColumn]) -> Result { + if columns.is_empty() { + return Err(ArrowError::InvalidArgumentError( + "Sort requires at least one column".to_string(), + )); + } + let num_rows = columns[0].values.len(); + if columns.iter().any(|item| item.values.len() != num_rows) { + return Err(ArrowError::ComputeError( + "Lexical sort columns have different row counts".to_string(), + )); + }; - let mut result = vec![]; - if row_count == 0 { - return Ok(result); + let comparator = LexicographicalComparator::try_new(columns)?; + let value_indices = (0..num_rows).collect::>(); + Ok(LexicographicalPartitionIterator { + comparator, + num_rows, + previous_partition_point: 0, + partition_point: 0, + value_indices, + }) } +} - let lexicographical_comparator = LexicographicalComparator::try_new(columns)?; - let value_indices = (0..row_count).collect::>(); +impl<'a> Iterator for LexicographicalPartitionIterator<'a> { + type Item = Range; - let mut previous_partition_point = 0; - result.push(previous_partition_point); - while previous_partition_point < row_count { - // invariant: - // value_indices[0..previous_partition_point] all are values <= value_indices[previous_partition_point] - // so in order to save time we can do binary search on the value_indices[previous_partition_point..] - // and find when any value is greater than value_indices[previous_partition_point]; because we are using - // new indices, the new offset is _added_ to the previous_partition_point. - // - // be careful that idx is of type &usize which points to the actual value within value_indices, which itself - // contains usize (0..row_count), providing access to lexicographical_comparator as pointers into the - // original columnar data. - previous_partition_point += value_indices[previous_partition_point..] - .partition_point(|idx| { - lexicographical_comparator.compare(idx, &previous_partition_point) - != Ordering::Greater - }); - result.push(previous_partition_point); + fn next(&mut self) -> Option { + if self.partition_point < self.num_rows { + // invariant: + // value_indices[0..previous_partition_point] all are values <= value_indices[previous_partition_point] + // so in order to save time we can do binary search on the value_indices[previous_partition_point..] + // and find when any value is greater than value_indices[previous_partition_point]; because we are using + // new indices, the new offset is _added_ to the previous_partition_point. + // + // be careful that idx is of type &usize which points to the actual value within value_indices, which itself + // contains usize (0..row_count), providing access to lexicographical_comparator as pointers into the + // original columnar data. + self.partition_point += self.value_indices[self.partition_point..] + .partition_point(|idx| { + self.comparator.compare(idx, &self.partition_point) + != Ordering::Greater + }); + let start = self.previous_partition_point; + let end = self.partition_point; + self.previous_partition_point = self.partition_point; + Some(Range { start, end }) + } else { + None + } } - - Ok(result) } #[cfg(test)] @@ -104,16 +111,16 @@ mod tests { use std::sync::Arc; #[test] - fn test_lexicographical_partition_points_empty() { + fn test_lexicographical_partition_ranges_empty() { let input = vec![]; assert!( - lexicographical_partition_points(&input).is_err(), - "lexicographical_partition_points should reject columns with empty rows" + lexicographical_partition_ranges(&input).is_err(), + "lexicographical_partition_ranges should reject columns with empty rows" ); } #[test] - fn test_lexicographical_partition_points_unaligned_rows() { + fn test_lexicographical_partition_ranges_unaligned_rows() { let input = vec![ SortColumn { values: Arc::new(Int64Array::from(vec![None, Some(-1)])) as ArrayRef, @@ -125,8 +132,8 @@ mod tests { }, ]; assert!( - lexicographical_partition_points(&input).is_err(), - "lexicographical_partition_points should reject columns with different row counts" + lexicographical_partition_ranges(&input).is_err(), + "lexicographical_partition_ranges should reject columns with different row counts" ); } @@ -140,15 +147,11 @@ mod tests { nulls_first: true, }), }]; - { - let results = lexicographical_partition_points(&input)?; - assert_eq!(vec![0, 1, 8, 9], results); - } { let results = lexicographical_partition_ranges(&input)?; assert_eq!( vec![(0_usize..1_usize), (1_usize..8_usize), (8_usize..9_usize)], - results + results.collect::>() ); } Ok(()) @@ -163,13 +166,10 @@ mod tests { nulls_first: true, }), }]; - { - let results = lexicographical_partition_points(&input)?; - assert_eq!(vec![0, 1000], results); - } + { let results = lexicographical_partition_ranges(&input)?; - assert_eq!(vec![(0_usize..1000_usize)], results); + assert_eq!(vec![(0_usize..1000_usize)], results.collect::>()); } Ok(()) } @@ -192,13 +192,9 @@ mod tests { }), }, ]; - { - let results = lexicographical_partition_points(&input)?; - assert_eq!(vec![0, 1000], results); - } { let results = lexicographical_partition_ranges(&input)?; - assert_eq!(vec![(0_usize..1000_usize)], results); + assert_eq!(vec![(0_usize..1000_usize)], results.collect::>()); } Ok(()) } @@ -222,13 +218,12 @@ mod tests { }), }, ]; - { - let results = lexicographical_partition_points(&input)?; - assert_eq!(vec![0, 1, 2], results); - } { let results = lexicographical_partition_ranges(&input)?; - assert_eq!(vec![(0_usize..1_usize), (1_usize..2_usize)], results); + assert_eq!( + vec![(0_usize..1_usize), (1_usize..2_usize)], + results.collect::>() + ); } Ok(()) } @@ -256,15 +251,11 @@ mod tests { }), }, ]; - { - let results = lexicographical_partition_points(&input)?; - assert_eq!(vec![0, 1, 2, 3], results); - } { let results = lexicographical_partition_ranges(&input)?; assert_eq!( vec![(0_usize..1_usize), (1_usize..2_usize), (2_usize..3_usize),], - results + results.collect::>() ); } Ok(()) @@ -298,15 +289,11 @@ mod tests { }), }, ]; - { - let results = lexicographical_partition_points(&input)?; - assert_eq!(vec![0, 1, 3, 4], results); - } { let results = lexicographical_partition_ranges(&input)?; assert_eq!( vec![(0_usize..1_usize), (1_usize..3_usize), (3_usize..4_usize),], - results + results.collect::>() ); } Ok(())