/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/sorts/cursor.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use std::cmp::Ordering; |
19 | | |
20 | | use arrow::buffer::ScalarBuffer; |
21 | | use arrow::compute::SortOptions; |
22 | | use arrow::datatypes::ArrowNativeTypeOp; |
23 | | use arrow::row::Rows; |
24 | | use arrow_array::types::ByteArrayType; |
25 | | use arrow_array::{ |
26 | | Array, ArrowPrimitiveType, GenericByteArray, OffsetSizeTrait, PrimitiveArray, |
27 | | }; |
28 | | use arrow_buffer::{Buffer, OffsetBuffer}; |
29 | | use datafusion_execution::memory_pool::MemoryReservation; |
30 | | |
31 | | /// A comparable collection of values for use with [`Cursor`] |
32 | | /// |
33 | | /// This is a trait as there are several specialized implementations, such as for |
34 | | /// single columns or for normalized multi column keys ([`Rows`]) |
35 | | pub trait CursorValues { |
36 | | fn len(&self) -> usize; |
37 | | |
38 | | /// Returns true if `l[l_idx] == r[r_idx]` |
39 | | fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool; |
40 | | |
41 | | /// Returns comparison of `l[l_idx]` and `r[r_idx]` |
42 | | fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering; |
43 | | } |
44 | | |
45 | | /// A comparable cursor, used by sort operations |
46 | | /// |
47 | | /// A `Cursor` is a pointer into a collection of rows, stored in |
48 | | /// [`CursorValues`] |
49 | | /// |
50 | | /// ```text |
51 | | /// |
52 | | /// ┌───────────────────────┐ |
53 | | /// │ │ ┌──────────────────────┐ |
54 | | /// │ ┌─────────┐ ┌─────┐ │ ─ ─ ─ ─│ Cursor<T> │ |
55 | | /// │ │ 1 │ │ A │ │ │ └──────────────────────┘ |
56 | | /// │ ├─────────┤ ├─────┤ │ |
57 | | /// │ │ 2 │ │ A │◀─ ┼ ─ ┘ Cursor<T> tracks an |
58 | | /// │ └─────────┘ └─────┘ │ offset within a |
59 | | /// │ ... ... │ CursorValues |
60 | | /// │ │ |
61 | | /// │ ┌─────────┐ ┌─────┐ │ |
62 | | /// │ │ 3 │ │ E │ │ |
63 | | /// │ └─────────┘ └─────┘ │ |
64 | | /// │ │ |
65 | | /// │ CursorValues │ |
66 | | /// └───────────────────────┘ |
67 | | /// |
68 | | /// |
69 | | /// Store logical rows using |
70 | | /// one of several formats, |
71 | | /// with specialized |
72 | | /// implementations |
73 | | /// depending on the column |
74 | | /// types |
75 | | #[derive(Debug)] |
76 | | pub struct Cursor<T: CursorValues> { |
77 | | offset: usize, |
78 | | values: T, |
79 | | } |
80 | | |
81 | | impl<T: CursorValues> Cursor<T> { |
82 | | /// Create a [`Cursor`] from the given [`CursorValues`] |
83 | 684 | pub fn new(values: T) -> Self { |
84 | 684 | Self { offset: 0, values } |
85 | 684 | } |
86 | | |
87 | | /// Returns true if there are no more rows in this cursor |
88 | 14.4k | pub fn is_finished(&self) -> bool { |
89 | 14.4k | self.offset == self.values.len() |
90 | 14.4k | } |
91 | | |
92 | | /// Advance the cursor, returning the previous row index |
93 | 14.5k | pub fn advance(&mut self) -> usize { |
94 | 14.5k | let t = self.offset; |
95 | 14.5k | self.offset += 1; |
96 | 14.5k | t |
97 | 14.5k | } |
98 | | } |
99 | | |
100 | | impl<T: CursorValues> PartialEq for Cursor<T> { |
101 | 7 | fn eq(&self, other: &Self) -> bool { |
102 | 7 | T::eq(&self.values, self.offset, &other.values, other.offset) |
103 | 7 | } |
104 | | } |
105 | | |
106 | | impl<T: CursorValues> Eq for Cursor<T> {} |
107 | | |
108 | | impl<T: CursorValues> PartialOrd for Cursor<T> { |
109 | 0 | fn partial_cmp(&self, other: &Self) -> Option<Ordering> { |
110 | 0 | Some(self.cmp(other)) |
111 | 0 | } |
112 | | } |
113 | | |
114 | | impl<T: CursorValues> Ord for Cursor<T> { |
115 | 31.9k | fn cmp(&self, other: &Self) -> Ordering { |
116 | 31.9k | T::compare(&self.values, self.offset, &other.values, other.offset) |
117 | 31.9k | } |
118 | | } |
119 | | |
120 | | /// Implements [`CursorValues`] for [`Rows`] |
121 | | /// |
122 | | /// Used for sorting when there are multiple columns in the sort key |
123 | | #[derive(Debug)] |
124 | | pub struct RowValues { |
125 | | rows: Rows, |
126 | | |
127 | | /// Tracks for the memory used by in the `Rows` of this |
128 | | /// cursor. Freed on drop |
129 | | #[allow(dead_code)] |
130 | | reservation: MemoryReservation, |
131 | | } |
132 | | |
133 | | impl RowValues { |
134 | | /// Create a new [`RowValues`] from `rows` and a `reservation` |
135 | | /// that tracks its memory. There must be at least one row |
136 | | /// |
137 | | /// Panics if the reservation is not for exactly `rows.size()` |
138 | | /// bytes or if `rows` is empty. |
139 | 11 | pub fn new(rows: Rows, reservation: MemoryReservation) -> Self { |
140 | 11 | assert_eq!( |
141 | 11 | rows.size(), |
142 | 11 | reservation.size(), |
143 | 0 | "memory reservation mismatch" |
144 | | ); |
145 | 11 | assert!(rows.num_rows() > 0); |
146 | 11 | Self { rows, reservation } |
147 | 11 | } |
148 | | } |
149 | | |
150 | | impl CursorValues for RowValues { |
151 | 55 | fn len(&self) -> usize { |
152 | 55 | self.rows.num_rows() |
153 | 55 | } |
154 | | |
155 | 0 | fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool { |
156 | 0 | l.rows.row(l_idx) == r.rows.row(r_idx) |
157 | 0 | } |
158 | | |
159 | 45 | fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering { |
160 | 45 | l.rows.row(l_idx).cmp(&r.rows.row(r_idx)) |
161 | 45 | } |
162 | | } |
163 | | |
164 | | /// An [`Array`] that can be converted into [`CursorValues`] |
165 | | pub trait CursorArray: Array + 'static { |
166 | | type Values: CursorValues; |
167 | | |
168 | | fn values(&self) -> Self::Values; |
169 | | } |
170 | | |
171 | | impl<T: ArrowPrimitiveType> CursorArray for PrimitiveArray<T> { |
172 | | type Values = PrimitiveValues<T::Native>; |
173 | | |
174 | 653 | fn values(&self) -> Self::Values { |
175 | 653 | PrimitiveValues(self.values().clone()) |
176 | 653 | } |
177 | | } |
178 | | |
179 | | #[derive(Debug)] |
180 | | pub struct PrimitiveValues<T: ArrowNativeTypeOp>(ScalarBuffer<T>); |
181 | | |
182 | | impl<T: ArrowNativeTypeOp> CursorValues for PrimitiveValues<T> { |
183 | 14.4k | fn len(&self) -> usize { |
184 | 14.4k | self.0.len() |
185 | 14.4k | } |
186 | | |
187 | 1 | fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool { |
188 | 1 | l.0[l_idx].is_eq(r.0[r_idx]) |
189 | 1 | } |
190 | | |
191 | 31.8k | fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering { |
192 | 31.8k | l.0[l_idx].compare(r.0[r_idx]) |
193 | 31.8k | } |
194 | | } |
195 | | |
196 | | pub struct ByteArrayValues<T: OffsetSizeTrait> { |
197 | | offsets: OffsetBuffer<T>, |
198 | | values: Buffer, |
199 | | } |
200 | | |
201 | | impl<T: OffsetSizeTrait> ByteArrayValues<T> { |
202 | 108 | fn value(&self, idx: usize) -> &[u8] { |
203 | 108 | assert!(idx < self.len()); |
204 | | // Safety: offsets are valid and checked bounds above |
205 | | unsafe { |
206 | 108 | let start = self.offsets.get_unchecked(idx).as_usize(); |
207 | 108 | let end = self.offsets.get_unchecked(idx + 1).as_usize(); |
208 | 108 | self.values.get_unchecked(start..end) |
209 | 108 | } |
210 | 108 | } |
211 | | } |
212 | | |
213 | | impl<T: OffsetSizeTrait> CursorValues for ByteArrayValues<T> { |
214 | 132 | fn len(&self) -> usize { |
215 | 132 | self.offsets.len() - 1 |
216 | 132 | } |
217 | | |
218 | 0 | fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool { |
219 | 0 | l.value(l_idx) == r.value(r_idx) |
220 | 0 | } |
221 | | |
222 | 54 | fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering { |
223 | 54 | l.value(l_idx).cmp(r.value(r_idx)) |
224 | 54 | } |
225 | | } |
226 | | |
227 | | impl<T: ByteArrayType> CursorArray for GenericByteArray<T> { |
228 | | type Values = ByteArrayValues<T::Offset>; |
229 | | |
230 | 12 | fn values(&self) -> Self::Values { |
231 | 12 | ByteArrayValues { |
232 | 12 | offsets: self.offsets().clone(), |
233 | 12 | values: self.values().clone(), |
234 | 12 | } |
235 | 12 | } |
236 | | } |
237 | | |
238 | | /// A collection of sorted, nullable [`CursorValues`] |
239 | | /// |
240 | | /// Note: comparing cursors with different `SortOptions` will yield an arbitrary ordering |
241 | | #[derive(Debug)] |
242 | | pub struct ArrayValues<T: CursorValues> { |
243 | | values: T, |
244 | | // If nulls first, the first non-null index |
245 | | // Otherwise, the first null index |
246 | | null_threshold: usize, |
247 | | options: SortOptions, |
248 | | } |
249 | | |
250 | | impl<T: CursorValues> ArrayValues<T> { |
251 | | /// Create a new [`ArrayValues`] from the provided `values` sorted according |
252 | | /// to `options`. |
253 | | /// |
254 | | /// Panics if the array is empty |
255 | 665 | pub fn new<A: CursorArray<Values = T>>(options: SortOptions, array: &A) -> Self { |
256 | 665 | assert!(array.len() > 0, "Empty array passed to FieldCursor"0 ); |
257 | 665 | let null_threshold = match options.nulls_first { |
258 | 665 | true => array.null_count(), |
259 | 0 | false => array.len() - array.null_count(), |
260 | | }; |
261 | | |
262 | 665 | Self { |
263 | 665 | values: array.values(), |
264 | 665 | null_threshold, |
265 | 665 | options, |
266 | 665 | } |
267 | 665 | } |
268 | | |
269 | 63.8k | fn is_null(&self, idx: usize) -> bool { |
270 | 63.8k | (idx < self.null_threshold) == self.options.nulls_first |
271 | 63.8k | } |
272 | | } |
273 | | |
274 | | impl<T: CursorValues> CursorValues for ArrayValues<T> { |
275 | 14.4k | fn len(&self) -> usize { |
276 | 14.4k | self.values.len() |
277 | 14.4k | } |
278 | | |
279 | 7 | fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool { |
280 | 7 | match (l.is_null(l_idx), r.is_null(r_idx)) { |
281 | 6 | (true, true) => true, |
282 | 1 | (false, false) => T::eq(&l.values, l_idx, &r.values, r_idx), |
283 | 0 | _ => false, |
284 | | } |
285 | 7 | } |
286 | | |
287 | 31.9k | fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering { |
288 | 31.9k | match (l.is_null(l_idx), r.is_null(r_idx)) { |
289 | 6 | (true, true) => Ordering::Equal, |
290 | 2 | (true, false) => match l.options.nulls_first { |
291 | 2 | true => Ordering::Less, |
292 | 0 | false => Ordering::Greater, |
293 | | }, |
294 | 4 | (false, true) => match l.options.nulls_first { |
295 | 0 | true => Ordering::Greater, |
296 | 4 | false => Ordering::Less, |
297 | | }, |
298 | 31.8k | (false, false) => match l.options.descending { |
299 | 800 | true => T::compare(&r.values, r_idx, &l.values, l_idx), |
300 | 31.0k | false => T::compare(&l.values, l_idx, &r.values, r_idx), |
301 | | }, |
302 | | } |
303 | 31.9k | } |
304 | | } |
305 | | |
306 | | #[cfg(test)] |
307 | | mod tests { |
308 | | use super::*; |
309 | | |
310 | 8 | fn new_primitive( |
311 | 8 | options: SortOptions, |
312 | 8 | values: ScalarBuffer<i32>, |
313 | 8 | null_count: usize, |
314 | 8 | ) -> Cursor<ArrayValues<PrimitiveValues<i32>>> { |
315 | 8 | let null_threshold = match options.nulls_first { |
316 | 4 | true => null_count, |
317 | 4 | false => values.len() - null_count, |
318 | | }; |
319 | | |
320 | 8 | let values = ArrayValues { |
321 | 8 | values: PrimitiveValues(values), |
322 | 8 | null_threshold, |
323 | 8 | options, |
324 | 8 | }; |
325 | 8 | |
326 | 8 | Cursor::new(values) |
327 | 8 | } |
328 | | |
329 | | #[test] |
330 | 1 | fn test_primitive_nulls_first() { |
331 | 1 | let options = SortOptions { |
332 | 1 | descending: false, |
333 | 1 | nulls_first: true, |
334 | 1 | }; |
335 | 1 | |
336 | 1 | let buffer = ScalarBuffer::from(vec![i32::MAX, 1, 2, 3]); |
337 | 1 | let mut a = new_primitive(options, buffer, 1); |
338 | 1 | let buffer = ScalarBuffer::from(vec![1, 2, -2, -1, 1, 9]); |
339 | 1 | let mut b = new_primitive(options, buffer, 2); |
340 | 1 | |
341 | 1 | // NULL == NULL |
342 | 1 | assert_eq!(a.cmp(&b), Ordering::Equal); |
343 | 1 | assert_eq!(a, b); |
344 | | |
345 | | // NULL == NULL |
346 | 1 | b.advance(); |
347 | 1 | assert_eq!(a.cmp(&b), Ordering::Equal); |
348 | 1 | assert_eq!(a, b); |
349 | | |
350 | | // NULL < -2 |
351 | 1 | b.advance(); |
352 | 1 | assert_eq!(a.cmp(&b), Ordering::Less); |
353 | | |
354 | | // 1 > -2 |
355 | 1 | a.advance(); |
356 | 1 | assert_eq!(a.cmp(&b), Ordering::Greater); |
357 | | |
358 | | // 1 > -1 |
359 | 1 | b.advance(); |
360 | 1 | assert_eq!(a.cmp(&b), Ordering::Greater); |
361 | | |
362 | | // 1 == 1 |
363 | 1 | b.advance(); |
364 | 1 | assert_eq!(a.cmp(&b), Ordering::Equal); |
365 | 1 | assert_eq!(a, b); |
366 | | |
367 | | // 9 > 1 |
368 | 1 | b.advance(); |
369 | 1 | assert_eq!(a.cmp(&b), Ordering::Less); |
370 | | |
371 | | // 9 > 2 |
372 | 1 | a.advance(); |
373 | 1 | assert_eq!(a.cmp(&b), Ordering::Less); |
374 | | |
375 | 1 | let options = SortOptions { |
376 | 1 | descending: false, |
377 | 1 | nulls_first: false, |
378 | 1 | }; |
379 | 1 | |
380 | 1 | let buffer = ScalarBuffer::from(vec![0, 1, i32::MIN, i32::MAX]); |
381 | 1 | let mut a = new_primitive(options, buffer, 2); |
382 | 1 | let buffer = ScalarBuffer::from(vec![-1, i32::MAX, i32::MIN]); |
383 | 1 | let mut b = new_primitive(options, buffer, 2); |
384 | 1 | |
385 | 1 | // 0 > -1 |
386 | 1 | assert_eq!(a.cmp(&b), Ordering::Greater); |
387 | | |
388 | | // 0 < NULL |
389 | 1 | b.advance(); |
390 | 1 | assert_eq!(a.cmp(&b), Ordering::Less); |
391 | | |
392 | | // 1 < NULL |
393 | 1 | a.advance(); |
394 | 1 | assert_eq!(a.cmp(&b), Ordering::Less); |
395 | | |
396 | | // NULL = NULL |
397 | 1 | a.advance(); |
398 | 1 | assert_eq!(a.cmp(&b), Ordering::Equal); |
399 | 1 | assert_eq!(a, b); |
400 | | |
401 | 1 | let options = SortOptions { |
402 | 1 | descending: true, |
403 | 1 | nulls_first: false, |
404 | 1 | }; |
405 | 1 | |
406 | 1 | let buffer = ScalarBuffer::from(vec![6, 1, i32::MIN, i32::MAX]); |
407 | 1 | let mut a = new_primitive(options, buffer, 3); |
408 | 1 | let buffer = ScalarBuffer::from(vec![67, -3, i32::MAX, i32::MIN]); |
409 | 1 | let mut b = new_primitive(options, buffer, 2); |
410 | 1 | |
411 | 1 | // 6 > 67 |
412 | 1 | assert_eq!(a.cmp(&b), Ordering::Greater); |
413 | | |
414 | | // 6 < -3 |
415 | 1 | b.advance(); |
416 | 1 | assert_eq!(a.cmp(&b), Ordering::Less); |
417 | | |
418 | | // 6 < NULL |
419 | 1 | b.advance(); |
420 | 1 | assert_eq!(a.cmp(&b), Ordering::Less); |
421 | | |
422 | | // 6 < NULL |
423 | 1 | b.advance(); |
424 | 1 | assert_eq!(a.cmp(&b), Ordering::Less); |
425 | | |
426 | | // NULL == NULL |
427 | 1 | a.advance(); |
428 | 1 | assert_eq!(a.cmp(&b), Ordering::Equal); |
429 | 1 | assert_eq!(a, b); |
430 | | |
431 | 1 | let options = SortOptions { |
432 | 1 | descending: true, |
433 | 1 | nulls_first: true, |
434 | 1 | }; |
435 | 1 | |
436 | 1 | let buffer = ScalarBuffer::from(vec![i32::MIN, i32::MAX, 6, 3]); |
437 | 1 | let mut a = new_primitive(options, buffer, 2); |
438 | 1 | let buffer = ScalarBuffer::from(vec![i32::MAX, 4546, -3]); |
439 | 1 | let mut b = new_primitive(options, buffer, 1); |
440 | 1 | |
441 | 1 | // NULL == NULL |
442 | 1 | assert_eq!(a.cmp(&b), Ordering::Equal); |
443 | 1 | assert_eq!(a, b); |
444 | | |
445 | | // NULL == NULL |
446 | 1 | a.advance(); |
447 | 1 | assert_eq!(a.cmp(&b), Ordering::Equal); |
448 | 1 | assert_eq!(a, b); |
449 | | |
450 | | // NULL < 4546 |
451 | 1 | b.advance(); |
452 | 1 | assert_eq!(a.cmp(&b), Ordering::Less); |
453 | | |
454 | | // 6 > 4546 |
455 | 1 | a.advance(); |
456 | 1 | assert_eq!(a.cmp(&b), Ordering::Greater); |
457 | | |
458 | | // 6 < -3 |
459 | 1 | b.advance(); |
460 | 1 | assert_eq!(a.cmp(&b), Ordering::Less); |
461 | 1 | } |
462 | | } |