Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/coalesce/mod.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
use arrow::compute::concat_batches;
19
use arrow_array::builder::StringViewBuilder;
20
use arrow_array::cast::AsArray;
21
use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions};
22
use arrow_schema::SchemaRef;
23
use std::sync::Arc;
24
25
/// Concatenate multiple [`RecordBatch`]es
26
///
27
/// `BatchCoalescer` concatenates multiple small [`RecordBatch`]es, produced by
28
/// operations such as `FilterExec` and `RepartitionExec`, into larger ones for
29
/// more efficient processing by subsequent operations.
30
///
31
/// # Background
32
///
33
/// Generally speaking, larger [`RecordBatch`]es are more efficient to process
34
/// than smaller record batches (until the CPU cache is exceeded) because there
35
/// is fixed processing overhead per batch. DataFusion tries to operate on
36
/// batches of `target_batch_size` rows to amortize this overhead
37
///
38
/// ```text
39
/// ┌────────────────────┐
40
/// │    RecordBatch     │
41
/// │   num_rows = 23    │
42
/// └────────────────────┘                 ┌────────────────────┐
43
///                                        │                    │
44
/// ┌────────────────────┐     Coalesce    │                    │
45
/// │                    │      Batches    │                    │
46
/// │    RecordBatch     │                 │                    │
47
/// │   num_rows = 50    │  ─ ─ ─ ─ ─ ─ ▶  │                    │
48
/// │                    │                 │    RecordBatch     │
49
/// │                    │                 │   num_rows = 106   │
50
/// └────────────────────┘                 │                    │
51
///                                        │                    │
52
/// ┌────────────────────┐                 │                    │
53
/// │                    │                 │                    │
54
/// │    RecordBatch     │                 │                    │
55
/// │   num_rows = 33    │                 └────────────────────┘
56
/// │                    │
57
/// └────────────────────┘
58
/// ```
59
///
60
/// # Notes:
61
///
62
/// 1. Output rows are produced in the same order as the input rows
63
///
64
/// 2. The output is a sequence of batches, with all but the last being at least
65
///    `target_batch_size` rows.
66
///
67
/// 3. Eventually this may also be able to handle other optimizations such as a
68
///    combined filter/coalesce operation.
69
///
70
#[derive(Debug)]
71
pub struct BatchCoalescer {
72
    /// The input schema
73
    schema: SchemaRef,
74
    /// Minimum number of rows for coalesced batches
75
    target_batch_size: usize,
76
    /// Total number of rows returned so far
77
    total_rows: usize,
78
    /// Buffered batches
79
    buffer: Vec<RecordBatch>,
80
    /// Buffered row count
81
    buffered_rows: usize,
82
    /// Limit: maximum number of rows to fetch, `None` means fetch all rows
83
    fetch: Option<usize>,
84
}
85
86
impl BatchCoalescer {
87
    /// Create a new `BatchCoalescer`
88
    ///
89
    /// # Arguments
90
    /// - `schema` - the schema of the output batches
91
    /// - `target_batch_size` - the minimum number of rows for each
92
    ///    output batch (until limit reached)
93
    /// - `fetch` - the maximum number of rows to fetch, `None` means fetch all rows
94
10
    pub fn new(
95
10
        schema: SchemaRef,
96
10
        target_batch_size: usize,
97
10
        fetch: Option<usize>,
98
10
    ) -> Self {
99
10
        Self {
100
10
            schema,
101
10
            target_batch_size,
102
10
            total_rows: 0,
103
10
            buffer: vec![],
104
10
            buffered_rows: 0,
105
10
            fetch,
106
10
        }
107
10
    }
108
109
    /// Return the schema of the output batches
110
0
    pub fn schema(&self) -> SchemaRef {
111
0
        Arc::clone(&self.schema)
112
0
    }
113
114
    /// Push next batch, and returns [`CoalescerState`] indicating the current
115
    /// state of the buffer.
116
60
    pub fn push_batch(&mut self, batch: RecordBatch) -> CoalescerState {
117
60
        let batch = gc_string_view_batch(&batch);
118
60
        if self.limit_reached(&batch) {
119
4
            CoalescerState::LimitReached
120
56
        } else if self.target_reached(batch) {
121
9
            CoalescerState::TargetReached
122
        } else {
123
47
            CoalescerState::Continue
124
        }
125
60
    }
126
127
    /// Return true if there is no data buffered
128
8
    pub fn is_empty(&self) -> bool {
129
8
        self.buffer.is_empty()
130
8
    }
131
132
    /// Checks if the buffer will reach the specified limit after getting
133
    /// `batch`.
134
    ///
135
    /// If fetch would be exceeded, slices the received batch, updates the
136
    /// buffer with it, and returns `true`.
137
    ///
138
    /// Otherwise: does nothing and returns `false`.
139
60
    fn limit_reached(&mut self, batch: &RecordBatch) -> bool {
140
26
        match self.fetch {
141
26
            Some(
fetch4
) if self.total_rows + batch.num_rows() >= fetch => {
142
4
                // Limit is reached
143
4
                let remaining_rows = fetch - self.total_rows;
144
4
                debug_assert!(remaining_rows > 0);
145
146
4
                let batch = batch.slice(0, remaining_rows);
147
4
                self.buffered_rows += batch.num_rows();
148
4
                self.total_rows = fetch;
149
4
                self.buffer.push(batch);
150
4
                true
151
            }
152
56
            _ => false,
153
        }
154
60
    }
155
156
    /// Updates the buffer with the given batch.
157
    ///
158
    /// If the target batch size is reached, returns `true`. Otherwise, returns
159
    /// `false`.
160
56
    fn target_reached(&mut self, batch: RecordBatch) -> bool {
161
56
        if batch.num_rows() == 0 {
162
0
            false
163
        } else {
164
56
            self.total_rows += batch.num_rows();
165
56
            self.buffered_rows += batch.num_rows();
166
56
            self.buffer.push(batch);
167
56
            self.buffered_rows >= self.target_batch_size
168
        }
169
56
    }
170
171
    /// Concatenates and returns all buffered batches, and clears the buffer.
172
17
    pub fn finish_batch(&mut self) -> datafusion_common::Result<RecordBatch> {
173
17
        let batch = concat_batches(&self.schema, &self.buffer)
?0
;
174
17
        self.buffer.clear();
175
17
        self.buffered_rows = 0;
176
17
        Ok(batch)
177
17
    }
178
}
179
180
/// Indicates the state of the [`BatchCoalescer`] buffer after the
181
/// [`BatchCoalescer::push_batch()`] operation.
182
///
183
/// The caller should take different actions, depending on the variant returned.
184
pub enum CoalescerState {
185
    /// Neither the limit nor the target batch size is reached.
186
    ///
187
    /// Action: continue pushing batches.
188
    Continue,
189
    /// The limit has been reached.
190
    ///
191
    /// Action: call [`BatchCoalescer::finish_batch()`] to get the final
192
    /// buffered results as a batch and finish the query.
193
    LimitReached,
194
    /// The specified minimum number of rows a batch should have is reached.
195
    ///
196
    /// Action: call [`BatchCoalescer::finish_batch()`] to get the current
197
    /// buffered results as a batch and then continue pushing batches.
198
    TargetReached,
199
}
200
201
/// Heuristically compact `StringViewArray`s to reduce memory usage, if needed
202
///
203
/// Decides when to consolidate the StringView into a new buffer to reduce
204
/// memory usage and improve string locality for better performance.
205
///
206
/// This differs from `StringViewArray::gc` because:
207
/// 1. It may not compact the array depending on a heuristic.
208
/// 2. It uses a precise block size to reduce the number of buffers to track.
209
///
210
/// # Heuristic
211
///
212
/// If the average size of each view is larger than 32 bytes, we compact the array.
213
///
214
/// `StringViewArray` include pointers to buffer that hold the underlying data.
215
/// One of the great benefits of `StringViewArray` is that many operations
216
/// (e.g., `filter`) can be done without copying the underlying data.
217
///
218
/// However, after a while (e.g., after `FilterExec` or `HashJoinExec`) the
219
/// `StringViewArray` may only refer to a small portion of the buffer,
220
/// significantly increasing memory usage.
221
64
fn gc_string_view_batch(batch: &RecordBatch) -> RecordBatch {
222
64
    let new_columns: Vec<ArrayRef> = batch
223
64
        .columns()
224
64
        .iter()
225
135
        .map(|c| {
226
            // Try to re-create the `StringViewArray` to prevent holding the underlying buffer too long.
227
135
            let Some(
s3
) = c.as_string_view_opt() else {
228
132
                return Arc::clone(c);
229
            };
230
3
            let ideal_buffer_size: usize = s
231
3
                .views()
232
3
                .iter()
233
2.02k
                .map(|v| {
234
2.02k
                    let len = (*v as u32) as usize;
235
2.02k
                    if len > 12 {
236
1.02k
                        len
237
                    } else {
238
1.00k
                        0
239
                    }
240
2.02k
                })
241
3
                .sum();
242
3
            let actual_buffer_size = s.get_buffer_memory_size();
243
3
244
3
            // Re-creating the array copies data and can be time consuming.
245
3
            // We only do it if the array is sparse
246
3
            if actual_buffer_size > (ideal_buffer_size * 2) {
247
                // We set the block size to `ideal_buffer_size` so that the new StringViewArray only has one buffer, which accelerate later concat_batches.
248
                // See https://github.com/apache/arrow-rs/issues/6094 for more details.
249
2
                let mut builder = StringViewBuilder::with_capacity(s.len());
250
2
                if ideal_buffer_size > 0 {
251
1
                    builder = builder.with_fixed_block_size(ideal_buffer_size as u32);
252
1
                }
253
254
1.02k
                for v in 
s.iter()2
{
255
1.02k
                    builder.append_option(v);
256
1.02k
                }
257
258
2
                let gc_string = builder.finish();
259
2
260
2
                debug_assert!(gc_string.data_buffers().len() <= 1); // buffer count can be 0 if the `ideal_buffer_size` is 0
261
262
2
                Arc::new(gc_string)
263
            } else {
264
1
                Arc::clone(c)
265
            }
266
135
        })
267
64
        .collect();
268
64
    let mut options = RecordBatchOptions::new();
269
64
    options = options.with_row_count(Some(batch.num_rows()));
270
64
    RecordBatch::try_new_with_options(batch.schema(), new_columns, &options)
271
64
        .expect("Failed to re-create the gc'ed record batch")
272
64
}
273
274
#[cfg(test)]
275
mod tests {
276
    use std::ops::Range;
277
278
    use super::*;
279
280
    use arrow::datatypes::{DataType, Field, Schema};
281
    use arrow_array::builder::ArrayBuilder;
282
    use arrow_array::{StringViewArray, UInt32Array};
283
284
    #[test]
285
1
    fn test_coalesce() {
286
1
        let batch = uint32_batch(0..8);
287
1
        Test::new()
288
1
            .with_batches(std::iter::repeat(batch).take(10))
289
1
            // expected output is batches of at least 21 rows (except for the final batch)
290
1
            .with_target_batch_size(21)
291
1
            .with_expected_output_sizes(vec![24, 24, 24, 8])
292
1
            .run()
293
1
    }
294
295
    #[test]
296
1
    fn test_coalesce_with_fetch_larger_than_input_size() {
297
1
        let batch = uint32_batch(0..8);
298
1
        Test::new()
299
1
            .with_batches(std::iter::repeat(batch).take(10))
300
1
            // input is 10 batches x 8 rows (80 rows) with fetch limit of 100
301
1
            // expected to behave the same as `test_concat_batches`
302
1
            .with_target_batch_size(21)
303
1
            .with_fetch(Some(100))
304
1
            .with_expected_output_sizes(vec![24, 24, 24, 8])
305
1
            .run();
306
1
    }
307
308
    #[test]
309
1
    fn test_coalesce_with_fetch_less_than_input_size() {
310
1
        let batch = uint32_batch(0..8);
311
1
        Test::new()
312
1
            .with_batches(std::iter::repeat(batch).take(10))
313
1
            // input is 10 batches x 8 rows (80 rows) with fetch limit of 50
314
1
            .with_target_batch_size(21)
315
1
            .with_fetch(Some(50))
316
1
            .with_expected_output_sizes(vec![24, 24, 2])
317
1
            .run();
318
1
    }
319
320
    #[test]
321
1
    fn test_coalesce_with_fetch_less_than_target_and_no_remaining_rows() {
322
1
        let batch = uint32_batch(0..8);
323
1
        Test::new()
324
1
            .with_batches(std::iter::repeat(batch).take(10))
325
1
            // input is 10 batches x 8 rows (80 rows) with fetch limit of 48
326
1
            .with_target_batch_size(21)
327
1
            .with_fetch(Some(48))
328
1
            .with_expected_output_sizes(vec![24, 24])
329
1
            .run();
330
1
    }
331
332
    #[test]
333
1
    fn test_coalesce_with_fetch_less_target_batch_size() {
334
1
        let batch = uint32_batch(0..8);
335
1
        Test::new()
336
1
            .with_batches(std::iter::repeat(batch).take(10))
337
1
            // input is 10 batches x 8 rows (80 rows) with fetch limit of 10
338
1
            .with_target_batch_size(21)
339
1
            .with_fetch(Some(10))
340
1
            .with_expected_output_sizes(vec![10])
341
1
            .run();
342
1
    }
343
344
    #[test]
345
1
    fn test_coalesce_single_large_batch_over_fetch() {
346
1
        let large_batch = uint32_batch(0..100);
347
1
        Test::new()
348
1
            .with_batch(large_batch)
349
1
            .with_target_batch_size(20)
350
1
            .with_fetch(Some(7))
351
1
            .with_expected_output_sizes(vec![7])
352
1
            .run()
353
1
    }
354
355
    /// Test for [`BatchCoalescer`]
356
    ///
357
    /// Pushes the input batches to the coalescer and verifies that the resulting
358
    /// batches have the expected number of rows and contents.
359
    #[derive(Debug, Clone, Default)]
360
    struct Test {
361
        /// Batches to feed to the coalescer. Tests must have at least one
362
        /// schema
363
        input_batches: Vec<RecordBatch>,
364
        /// Expected output sizes of the resulting batches
365
        expected_output_sizes: Vec<usize>,
366
        /// target batch size
367
        target_batch_size: usize,
368
        /// Fetch (limit)
369
        fetch: Option<usize>,
370
    }
371
372
    impl Test {
373
6
        fn new() -> Self {
374
6
            Self::default()
375
6
        }
376
377
        /// Set the target batch size
378
6
        fn with_target_batch_size(mut self, target_batch_size: usize) -> Self {
379
6
            self.target_batch_size = target_batch_size;
380
6
            self
381
6
        }
382
383
        /// Set the fetch (limit)
384
5
        fn with_fetch(mut self, fetch: Option<usize>) -> Self {
385
5
            self.fetch = fetch;
386
5
            self
387
5
        }
388
389
        /// Extend the input batches with `batch`
390
1
        fn with_batch(mut self, batch: RecordBatch) -> Self {
391
1
            self.input_batches.push(batch);
392
1
            self
393
1
        }
394
395
        /// Extends the input batches with `batches`
396
5
        fn with_batches(
397
5
            mut self,
398
5
            batches: impl IntoIterator<Item = RecordBatch>,
399
5
        ) -> Self {
400
5
            self.input_batches.extend(batches);
401
5
            self
402
5
        }
403
404
        /// Extends `sizes` to expected output sizes
405
6
        fn with_expected_output_sizes(
406
6
            mut self,
407
6
            sizes: impl IntoIterator<Item = usize>,
408
6
        ) -> Self {
409
6
            self.expected_output_sizes.extend(sizes);
410
6
            self
411
6
        }
412
413
        /// Runs the test -- see documentation on [`Test`] for details
414
6
        fn run(self) {
415
6
            let Self {
416
6
                input_batches,
417
6
                target_batch_size,
418
6
                fetch,
419
6
                expected_output_sizes,
420
6
            } = self;
421
6
422
6
            let schema = input_batches[0].schema();
423
6
424
6
            // create a single large input batch for output comparison
425
6
            let single_input_batch = concat_batches(&schema, &input_batches).unwrap();
426
6
427
6
            let mut coalescer =
428
6
                BatchCoalescer::new(Arc::clone(&schema), target_batch_size, fetch);
429
6
430
6
            let mut output_batches = vec![];
431
38
            for 
batch36
in input_batches {
432
36
                match coalescer.push_batch(batch) {
433
23
                    CoalescerState::Continue => {}
434
                    CoalescerState::LimitReached => {
435
4
                        output_batches.push(coalescer.finish_batch().unwrap());
436
4
                        break;
437
                    }
438
9
                    CoalescerState::TargetReached => {
439
9
                        coalescer.buffered_rows = 0;
440
9
                        output_batches.push(coalescer.finish_batch().unwrap());
441
9
                    }
442
                }
443
            }
444
6
            if coalescer.buffered_rows != 0 {
445
2
                output_batches.extend(coalescer.buffer);
446
4
            }
447
448
            // make sure we got the expected number of output batches and content
449
6
            let mut starting_idx = 0;
450
6
            assert_eq!(expected_output_sizes.len(), output_batches.len());
451
15
            for (i, (expected_size, batch)) in
452
6
                expected_output_sizes.iter().zip(output_batches).enumerate()
453
            {
454
15
                assert_eq!(
455
15
                    *expected_size,
456
15
                    batch.num_rows(),
457
0
                    "Unexpected number of rows in Batch {i}"
458
                );
459
460
                // compare the contents of the batch (using `==` compares the
461
                // underlying memory layout too)
462
15
                let expected_batch =
463
15
                    single_input_batch.slice(starting_idx, *expected_size);
464
15
                let batch_strings = batch_to_pretty_strings(&batch);
465
15
                let expected_batch_strings = batch_to_pretty_strings(&expected_batch);
466
15
                let batch_strings = batch_strings.lines().collect::<Vec<_>>();
467
15
                let expected_batch_strings =
468
15
                    expected_batch_strings.lines().collect::<Vec<_>>();
469
15
                assert_eq!(
470
                    expected_batch_strings, batch_strings,
471
0
                    "Unexpected content in Batch {i}:\
472
0
                    \n\nExpected:\n{expected_batch_strings:#?}\n\nActual:\n{batch_strings:#?}"
473
                );
474
15
                starting_idx += *expected_size;
475
            }
476
6
        }
477
    }
478
479
    /// Return a batch of UInt32 with the specified range
480
6
    fn uint32_batch(range: Range<u32>) -> RecordBatch {
481
6
        let schema =
482
6
            Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));
483
6
484
6
        RecordBatch::try_new(
485
6
            Arc::clone(&schema),
486
6
            vec![Arc::new(UInt32Array::from_iter_values(range))],
487
6
        )
488
6
        .unwrap()
489
6
    }
490
491
    #[test]
492
1
    fn test_gc_string_view_batch_small_no_compact() {
493
1
        // view with only short strings (no buffers) --> no need to compact
494
1
        let array = StringViewTest {
495
1
            rows: 1000,
496
1
            strings: vec![Some("a"), Some("b"), Some("c")],
497
1
        }
498
1
        .build();
499
1
500
1
        let gc_array = do_gc(array.clone());
501
1
        compare_string_array_values(&array, &gc_array);
502
1
        assert_eq!(array.data_buffers().len(), 0);
503
1
        assert_eq!(array.data_buffers().len(), gc_array.data_buffers().len()); // no compaction
504
1
    }
505
506
    #[test]
507
1
    fn test_gc_string_view_test_batch_empty() {
508
1
        let schema = Schema::empty();
509
1
        let batch = RecordBatch::new_empty(schema.into());
510
1
        let output_batch = gc_string_view_batch(&batch);
511
1
        assert_eq!(batch.num_columns(), output_batch.num_columns());
512
1
        assert_eq!(batch.num_rows(), output_batch.num_rows());
513
1
    }
514
515
    #[test]
516
1
    fn test_gc_string_view_batch_large_no_compact() {
517
1
        // view with large strings (has buffers) but full --> no need to compact
518
1
        let array = StringViewTest {
519
1
            rows: 1000,
520
1
            strings: vec![Some("This string is longer than 12 bytes")],
521
1
        }
522
1
        .build();
523
1
524
1
        let gc_array = do_gc(array.clone());
525
1
        compare_string_array_values(&array, &gc_array);
526
1
        assert_eq!(array.data_buffers().len(), 5);
527
1
        assert_eq!(array.data_buffers().len(), gc_array.data_buffers().len()); // no compaction
528
1
    }
529
530
    #[test]
531
1
    fn test_gc_string_view_batch_large_slice_compact() {
532
1
        // view with large strings (has buffers) and only partially used --> should be compacted
533
1
        let array = StringViewTest {
534
1
            rows: 1000,
535
1
            strings: vec![Some("this string is longer than 12 bytes")],
536
1
        }
537
1
        .build();
538
1
539
1
        // slice only 11 rows, so most of the buffer is not used
540
1
        let array = array.slice(11, 22);
541
1
542
1
        let gc_array = do_gc(array.clone());
543
1
        compare_string_array_values(&array, &gc_array);
544
1
        assert_eq!(array.data_buffers().len(), 5);
545
1
        assert_eq!(gc_array.data_buffers().len(), 1); // compacted into a single buffer
546
1
    }
547
548
    /// Compares the values of two string view arrays
549
3
    fn compare_string_array_values(arr1: &StringViewArray, arr2: &StringViewArray) {
550
3
        assert_eq!(arr1.len(), arr2.len());
551
2.02k
        for (s1, s2) in 
arr1.iter().zip(arr2.iter())3
{
552
2.02k
            assert_eq!(s1, s2);
553
        }
554
3
    }
555
556
    /// runs garbage collection on string view array
557
    /// and ensures the number of rows are the same
558
3
    fn do_gc(array: StringViewArray) -> StringViewArray {
559
3
        let batch =
560
3
            RecordBatch::try_from_iter(vec![("a", Arc::new(array) as ArrayRef)]).unwrap();
561
3
        let gc_batch = gc_string_view_batch(&batch);
562
3
        assert_eq!(batch.num_rows(), gc_batch.num_rows());
563
3
        assert_eq!(batch.schema(), gc_batch.schema());
564
3
        gc_batch
565
3
            .column(0)
566
3
            .as_any()
567
3
            .downcast_ref::<StringViewArray>()
568
3
            .unwrap()
569
3
            .clone()
570
3
    }
571
572
    /// Describes parameters for creating a `StringViewArray`
573
    struct StringViewTest {
574
        /// The number of rows in the array
575
        rows: usize,
576
        /// The strings to use in the array (repeated over and over)
577
        strings: Vec<Option<&'static str>>,
578
    }
579
580
    impl StringViewTest {
581
        /// Create a `StringViewArray` with the parameters specified in this struct
582
3
        fn build(self) -> StringViewArray {
583
3
            let mut builder =
584
3
                StringViewBuilder::with_capacity(100).with_fixed_block_size(8192);
585
            loop {
586
3.00k
                for &v in 
self.strings.iter()2.33k
{
587
3.00k
                    builder.append_option(v);
588
3.00k
                    if builder.len() >= self.rows {
589
3
                        return builder.finish();
590
2.99k
                    }
591
                }
592
            }
593
3
        }
594
    }
595
30
    fn batch_to_pretty_strings(batch: &RecordBatch) -> String {
596
30
        arrow::util::pretty::pretty_format_batches(&[batch.clone()])
597
30
            .unwrap()
598
30
            .to_string()
599
30
    }
600
}