Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/sorts/cursor.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
use std::cmp::Ordering;
19
20
use arrow::buffer::ScalarBuffer;
21
use arrow::compute::SortOptions;
22
use arrow::datatypes::ArrowNativeTypeOp;
23
use arrow::row::Rows;
24
use arrow_array::types::ByteArrayType;
25
use arrow_array::{
26
    Array, ArrowPrimitiveType, GenericByteArray, OffsetSizeTrait, PrimitiveArray,
27
};
28
use arrow_buffer::{Buffer, OffsetBuffer};
29
use datafusion_execution::memory_pool::MemoryReservation;
30
31
/// A comparable collection of values for use with [`Cursor`]
32
///
33
/// This is a trait as there are several specialized implementations, such as for
34
/// single columns or for normalized multi column keys ([`Rows`])
35
pub trait CursorValues {
36
    fn len(&self) -> usize;
37
38
    /// Returns true if `l[l_idx] == r[r_idx]`
39
    fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool;
40
41
    /// Returns comparison of `l[l_idx]` and `r[r_idx]`
42
    fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering;
43
}
44
45
/// A comparable cursor, used by sort operations
46
///
47
/// A `Cursor` is a pointer into a collection of rows, stored in
48
/// [`CursorValues`]
49
///
50
/// ```text
51
///
52
/// ┌───────────────────────┐
53
/// │                       │           ┌──────────────────────┐
54
/// │ ┌─────────┐ ┌─────┐   │    ─ ─ ─ ─│      Cursor<T>       │
55
/// │ │    1    │ │  A  │   │   │       └──────────────────────┘
56
/// │ ├─────────┤ ├─────┤   │
57
/// │ │    2    │ │  A  │◀─ ┼ ─ ┘          Cursor<T> tracks an
58
/// │ └─────────┘ └─────┘   │                offset within a
59
/// │     ...       ...     │                  CursorValues
60
/// │                       │
61
/// │ ┌─────────┐ ┌─────┐   │
62
/// │ │    3    │ │  E  │   │
63
/// │ └─────────┘ └─────┘   │
64
/// │                       │
65
/// │     CursorValues      │
66
/// └───────────────────────┘
67
///
68
///
69
/// Store logical rows using
70
/// one of several  formats,
71
/// with specialized
72
/// implementations
73
/// depending on the column
74
/// types
/// ```
75
#[derive(Debug)]
76
pub struct Cursor<T: CursorValues> {
77
    offset: usize,
78
    values: T,
79
}
80
81
impl<T: CursorValues> Cursor<T> {
82
    /// Create a [`Cursor`] from the given [`CursorValues`]
83
684
    pub fn new(values: T) -> Self {
84
684
        Self { offset: 0, values }
85
684
    }
86
87
    /// Returns true if there are no more rows in this cursor
88
14.4k
    pub fn is_finished(&self) -> bool {
89
14.4k
        self.offset == self.values.len()
90
14.4k
    }
91
92
    /// Advance the cursor, returning the previous row index
93
14.5k
    pub fn advance(&mut self) -> usize {
94
14.5k
        let t = self.offset;
95
14.5k
        self.offset += 1;
96
14.5k
        t
97
14.5k
    }
98
}
99
100
impl<T: CursorValues> PartialEq for Cursor<T> {
101
7
    fn eq(&self, other: &Self) -> bool {
102
7
        T::eq(&self.values, self.offset, &other.values, other.offset)
103
7
    }
104
}
105
106
impl<T: CursorValues> Eq for Cursor<T> {}
107
108
impl<T: CursorValues> PartialOrd for Cursor<T> {
109
0
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
110
0
        Some(self.cmp(other))
111
0
    }
112
}
113
114
impl<T: CursorValues> Ord for Cursor<T> {
115
31.9k
    fn cmp(&self, other: &Self) -> Ordering {
116
31.9k
        T::compare(&self.values, self.offset, &other.values, other.offset)
117
31.9k
    }
118
}
119
120
/// Implements [`CursorValues`] for [`Rows`]
121
///
122
/// Used for sorting when there are multiple columns in the sort key
123
#[derive(Debug)]
124
pub struct RowValues {
125
    rows: Rows,
126
127
    /// Tracks the memory used by the `Rows` of this
128
    /// cursor. Freed on drop
129
    #[allow(dead_code)]
130
    reservation: MemoryReservation,
131
}
132
133
impl RowValues {
134
    /// Create a new [`RowValues`] from `rows` and a `reservation`
135
    /// that tracks its memory. There must be at least one row
136
    ///
137
    /// Panics if the reservation is not for exactly `rows.size()`
138
    /// bytes or if `rows` is empty.
139
11
    pub fn new(rows: Rows, reservation: MemoryReservation) -> Self {
140
11
        assert_eq!(
141
11
            rows.size(),
142
11
            reservation.size(),
143
0
            "memory reservation mismatch"
144
        );
145
11
        assert!(rows.num_rows() > 0);
146
11
        Self { rows, reservation }
147
11
    }
148
}
149
150
impl CursorValues for RowValues {
151
55
    fn len(&self) -> usize {
152
55
        self.rows.num_rows()
153
55
    }
154
155
0
    fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool {
156
0
        l.rows.row(l_idx) == r.rows.row(r_idx)
157
0
    }
158
159
45
    fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering {
160
45
        l.rows.row(l_idx).cmp(&r.rows.row(r_idx))
161
45
    }
162
}
163
164
/// An [`Array`] that can be converted into [`CursorValues`]
165
pub trait CursorArray: Array + 'static {
166
    type Values: CursorValues;
167
168
    fn values(&self) -> Self::Values;
169
}
170
171
impl<T: ArrowPrimitiveType> CursorArray for PrimitiveArray<T> {
172
    type Values = PrimitiveValues<T::Native>;
173
174
653
    fn values(&self) -> Self::Values {
175
653
        PrimitiveValues(self.values().clone())
176
653
    }
177
}
178
179
#[derive(Debug)]
180
pub struct PrimitiveValues<T: ArrowNativeTypeOp>(ScalarBuffer<T>);
181
182
impl<T: ArrowNativeTypeOp> CursorValues for PrimitiveValues<T> {
183
14.4k
    fn len(&self) -> usize {
184
14.4k
        self.0.len()
185
14.4k
    }
186
187
1
    fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool {
188
1
        l.0[l_idx].is_eq(r.0[r_idx])
189
1
    }
190
191
31.8k
    fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering {
192
31.8k
        l.0[l_idx].compare(r.0[r_idx])
193
31.8k
    }
194
}
195
196
pub struct ByteArrayValues<T: OffsetSizeTrait> {
197
    offsets: OffsetBuffer<T>,
198
    values: Buffer,
199
}
200
201
impl<T: OffsetSizeTrait> ByteArrayValues<T> {
202
108
    fn value(&self, idx: usize) -> &[u8] {
203
108
        assert!(idx < self.len());
204
        // Safety: the assert above guarantees `idx` and `idx + 1` are in-bounds offset indices, and the offsets are valid for `values`
205
        unsafe {
206
108
            let start = self.offsets.get_unchecked(idx).as_usize();
207
108
            let end = self.offsets.get_unchecked(idx + 1).as_usize();
208
108
            self.values.get_unchecked(start..end)
209
108
        }
210
108
    }
211
}
212
213
impl<T: OffsetSizeTrait> CursorValues for ByteArrayValues<T> {
214
132
    fn len(&self) -> usize {
215
132
        self.offsets.len() - 1
216
132
    }
217
218
0
    fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool {
219
0
        l.value(l_idx) == r.value(r_idx)
220
0
    }
221
222
54
    fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering {
223
54
        l.value(l_idx).cmp(r.value(r_idx))
224
54
    }
225
}
226
227
impl<T: ByteArrayType> CursorArray for GenericByteArray<T> {
228
    type Values = ByteArrayValues<T::Offset>;
229
230
12
    fn values(&self) -> Self::Values {
231
12
        ByteArrayValues {
232
12
            offsets: self.offsets().clone(),
233
12
            values: self.values().clone(),
234
12
        }
235
12
    }
236
}
237
238
/// A collection of sorted, nullable [`CursorValues`]
239
///
240
/// Note: comparing cursors with different `SortOptions` will yield an arbitrary ordering
241
#[derive(Debug)]
242
pub struct ArrayValues<T: CursorValues> {
243
    values: T,
244
    // If nulls first, the first non-null index
245
    // Otherwise, the first null index
246
    null_threshold: usize,
247
    options: SortOptions,
248
}
249
250
impl<T: CursorValues> ArrayValues<T> {
251
    /// Create a new [`ArrayValues`] from the provided `values` sorted according
252
    /// to `options`.
253
    ///
254
    /// Panics if the array is empty
255
665
    pub fn new<A: CursorArray<Values = T>>(options: SortOptions, array: &A) -> Self {
256
665
        assert!(array.len() > 0, "Empty array passed to FieldCursor");
257
665
        let null_threshold = match options.nulls_first {
258
665
            true => array.null_count(),
259
0
            false => array.len() - array.null_count(),
260
        };
261
262
665
        Self {
263
665
            values: array.values(),
264
665
            null_threshold,
265
665
            options,
266
665
        }
267
665
    }
268
269
63.8k
    fn is_null(&self, idx: usize) -> bool {
270
63.8k
        (idx < self.null_threshold) == self.options.nulls_first
271
63.8k
    }
272
}
273
274
impl<T: CursorValues> CursorValues for ArrayValues<T> {
275
14.4k
    fn len(&self) -> usize {
276
14.4k
        self.values.len()
277
14.4k
    }
278
279
7
    fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool {
280
7
        match (l.is_null(l_idx), r.is_null(r_idx)) {
281
6
            (true, true) => true,
282
1
            (false, false) => T::eq(&l.values, l_idx, &r.values, r_idx),
283
0
            _ => false,
284
        }
285
7
    }
286
287
31.9k
    fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering {
288
31.9k
        match (l.is_null(l_idx), r.is_null(r_idx)) {
289
6
            (true, true) => Ordering::Equal,
290
2
            (true, false) => match l.options.nulls_first {
291
2
                true => Ordering::Less,
292
0
                false => Ordering::Greater,
293
            },
294
4
            (false, true) => match l.options.nulls_first {
295
0
                true => Ordering::Greater,
296
4
                false => Ordering::Less,
297
            },
298
31.8k
            (false, false) => match l.options.descending {
299
800
                true => T::compare(&r.values, r_idx, &l.values, l_idx),
300
31.0k
                false => T::compare(&l.values, l_idx, &r.values, r_idx),
301
            },
302
        }
303
31.9k
    }
304
}
305
306
#[cfg(test)]
307
mod tests {
308
    use super::*;
309
310
8
    fn new_primitive(
311
8
        options: SortOptions,
312
8
        values: ScalarBuffer<i32>,
313
8
        null_count: usize,
314
8
    ) -> Cursor<ArrayValues<PrimitiveValues<i32>>> {
315
8
        let null_threshold = match options.nulls_first {
316
4
            true => null_count,
317
4
            false => values.len() - null_count,
318
        };
319
320
8
        let values = ArrayValues {
321
8
            values: PrimitiveValues(values),
322
8
            null_threshold,
323
8
            options,
324
8
        };
325
8
326
8
        Cursor::new(values)
327
8
    }
328
329
    #[test]
330
1
    fn test_primitive_nulls_first() {
331
1
        let options = SortOptions {
332
1
            descending: false,
333
1
            nulls_first: true,
334
1
        };
335
1
336
1
        let buffer = ScalarBuffer::from(vec![i32::MAX, 1, 2, 3]);
337
1
        let mut a = new_primitive(options, buffer, 1);
338
1
        let buffer = ScalarBuffer::from(vec![1, 2, -2, -1, 1, 9]);
339
1
        let mut b = new_primitive(options, buffer, 2);
340
1
341
1
        // NULL == NULL
342
1
        assert_eq!(a.cmp(&b), Ordering::Equal);
343
1
        assert_eq!(a, b);
344
345
        // NULL == NULL
346
1
        b.advance();
347
1
        assert_eq!(a.cmp(&b), Ordering::Equal);
348
1
        assert_eq!(a, b);
349
350
        // NULL < -2
351
1
        b.advance();
352
1
        assert_eq!(a.cmp(&b), Ordering::Less);
353
354
        // 1 > -2
355
1
        a.advance();
356
1
        assert_eq!(a.cmp(&b), Ordering::Greater);
357
358
        // 1 > -1
359
1
        b.advance();
360
1
        assert_eq!(a.cmp(&b), Ordering::Greater);
361
362
        // 1 == 1
363
1
        b.advance();
364
1
        assert_eq!(a.cmp(&b), Ordering::Equal);
365
1
        assert_eq!(a, b);
366
367
        // 1 < 9
368
1
        b.advance();
369
1
        assert_eq!(a.cmp(&b), Ordering::Less);
370
371
        // 2 < 9
372
1
        a.advance();
373
1
        assert_eq!(a.cmp(&b), Ordering::Less);
374
375
1
        let options = SortOptions {
376
1
            descending: false,
377
1
            nulls_first: false,
378
1
        };
379
1
380
1
        let buffer = ScalarBuffer::from(vec![0, 1, i32::MIN, i32::MAX]);
381
1
        let mut a = new_primitive(options, buffer, 2);
382
1
        let buffer = ScalarBuffer::from(vec![-1, i32::MAX, i32::MIN]);
383
1
        let mut b = new_primitive(options, buffer, 2);
384
1
385
1
        // 0 > -1
386
1
        assert_eq!(a.cmp(&b), Ordering::Greater);
387
388
        // 0 < NULL
389
1
        b.advance();
390
1
        assert_eq!(a.cmp(&b), Ordering::Less);
391
392
        // 1 < NULL
393
1
        a.advance();
394
1
        assert_eq!(a.cmp(&b), Ordering::Less);
395
396
        // NULL == NULL
397
1
        a.advance();
398
1
        assert_eq!(a.cmp(&b), Ordering::Equal);
399
1
        assert_eq!(a, b);
400
401
1
        let options = SortOptions {
402
1
            descending: true,
403
1
            nulls_first: false,
404
1
        };
405
1
406
1
        let buffer = ScalarBuffer::from(vec![6, 1, i32::MIN, i32::MAX]);
407
1
        let mut a = new_primitive(options, buffer, 3);
408
1
        let buffer = ScalarBuffer::from(vec![67, -3, i32::MAX, i32::MIN]);
409
1
        let mut b = new_primitive(options, buffer, 2);
410
1
411
1
        // 6 > 67
412
1
        assert_eq!(a.cmp(&b), Ordering::Greater);
413
414
        // 6 < -3
415
1
        b.advance();
416
1
        assert_eq!(a.cmp(&b), Ordering::Less);
417
418
        // 6 < NULL
419
1
        b.advance();
420
1
        assert_eq!(a.cmp(&b), Ordering::Less);
421
422
        // 6 < NULL
423
1
        b.advance();
424
1
        assert_eq!(a.cmp(&b), Ordering::Less);
425
426
        // NULL == NULL
427
1
        a.advance();
428
1
        assert_eq!(a.cmp(&b), Ordering::Equal);
429
1
        assert_eq!(a, b);
430
431
1
        let options = SortOptions {
432
1
            descending: true,
433
1
            nulls_first: true,
434
1
        };
435
1
436
1
        let buffer = ScalarBuffer::from(vec![i32::MIN, i32::MAX, 6, 3]);
437
1
        let mut a = new_primitive(options, buffer, 2);
438
1
        let buffer = ScalarBuffer::from(vec![i32::MAX, 4546, -3]);
439
1
        let mut b = new_primitive(options, buffer, 1);
440
1
441
1
        // NULL == NULL
442
1
        assert_eq!(a.cmp(&b), Ordering::Equal);
443
1
        assert_eq!(a, b);
444
445
        // NULL == NULL
446
1
        a.advance();
447
1
        assert_eq!(a.cmp(&b), Ordering::Equal);
448
1
        assert_eq!(a, b);
449
450
        // NULL < 4546
451
1
        b.advance();
452
1
        assert_eq!(a.cmp(&b), Ordering::Less);
453
454
        // 6 > 4546
455
1
        a.advance();
456
1
        assert_eq!(a.cmp(&b), Ordering::Greater);
457
458
        // 6 < -3
459
1
        b.advance();
460
1
        assert_eq!(a.cmp(&b), Ordering::Less);
461
1
    }
462
}