Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/joins/cross_join.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Defines the cross join plan for loading the left side of the cross join
19
//! and producing batches in parallel for the right partitions
20
21
use super::utils::{
22
    adjust_right_output_partitioning, BuildProbeJoinMetrics, OnceAsync, OnceFut,
23
    StatefulStreamResult,
24
};
25
use crate::coalesce_partitions::CoalescePartitionsExec;
26
use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
27
use crate::{
28
    execution_mode_from_children, handle_state, ColumnStatistics, DisplayAs,
29
    DisplayFormatType, Distribution, ExecutionMode, ExecutionPlan,
30
    ExecutionPlanProperties, PlanProperties, RecordBatchStream,
31
    SendableRecordBatchStream, Statistics,
32
};
33
use arrow::compute::concat_batches;
34
use std::{any::Any, sync::Arc, task::Poll};
35
36
use arrow::datatypes::{Fields, Schema, SchemaRef};
37
use arrow::record_batch::RecordBatch;
38
use arrow_array::RecordBatchOptions;
39
use datafusion_common::stats::Precision;
40
use datafusion_common::{internal_err, JoinType, Result, ScalarValue};
41
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
42
use datafusion_execution::TaskContext;
43
use datafusion_physical_expr::equivalence::join_equivalence_properties;
44
45
use async_trait::async_trait;
46
use futures::{ready, Stream, StreamExt, TryStreamExt};
47
48
/// Data of the left side
49
type JoinLeftData = (RecordBatch, MemoryReservation);
50
51
/// executes partitions in parallel and combines them into a set of
52
/// partitions by combining all values from the left with all values on the right
53
#[derive(Debug)]
54
pub struct CrossJoinExec {
55
    /// left (build) side which gets loaded in memory
56
    pub left: Arc<dyn ExecutionPlan>,
57
    /// right (probe) side which are combined with left side
58
    pub right: Arc<dyn ExecutionPlan>,
59
    /// The schema once the join is applied
60
    schema: SchemaRef,
61
    /// Build-side data
62
    left_fut: OnceAsync<JoinLeftData>,
63
    /// Execution plan metrics
64
    metrics: ExecutionPlanMetricsSet,
65
    cache: PlanProperties,
66
}
67
68
impl CrossJoinExec {
69
    /// Create a new [CrossJoinExec].
70
2
    pub fn new(left: Arc<dyn ExecutionPlan>, right: Arc<dyn ExecutionPlan>) -> Self {
71
2
        // left then right
72
2
        let (all_columns, metadata) = {
73
2
            let left_schema = left.schema();
74
2
            let right_schema = right.schema();
75
2
            let left_fields = left_schema.fields().iter();
76
2
            let right_fields = right_schema.fields().iter();
77
2
78
2
            let mut metadata = left_schema.metadata().clone();
79
2
            metadata.extend(right_schema.metadata().clone());
80
2
81
2
            (
82
2
                left_fields.chain(right_fields).cloned().collect::<Fields>(),
83
2
                metadata,
84
2
            )
85
2
        };
86
2
87
2
        let schema = Arc::new(Schema::new(all_columns).with_metadata(metadata));
88
2
        let cache = Self::compute_properties(&left, &right, Arc::clone(&schema));
89
2
        CrossJoinExec {
90
2
            left,
91
2
            right,
92
2
            schema,
93
2
            left_fut: Default::default(),
94
2
            metrics: ExecutionPlanMetricsSet::default(),
95
2
            cache,
96
2
        }
97
2
    }
98
99
    /// left (build) side which gets loaded in memory
100
2
    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
101
2
        &self.left
102
2
    }
103
104
    /// right side which gets combined with left side
105
0
    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
106
0
        &self.right
107
0
    }
108
109
    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
110
2
    fn compute_properties(
111
2
        left: &Arc<dyn ExecutionPlan>,
112
2
        right: &Arc<dyn ExecutionPlan>,
113
2
        schema: SchemaRef,
114
2
    ) -> PlanProperties {
115
2
        // Calculate equivalence properties
116
2
        // TODO: Check equivalence properties of cross join, it may preserve
117
2
        //       ordering in some cases.
118
2
        let eq_properties = join_equivalence_properties(
119
2
            left.equivalence_properties().clone(),
120
2
            right.equivalence_properties().clone(),
121
2
            &JoinType::Full,
122
2
            schema,
123
2
            &[false, false],
124
2
            None,
125
2
            &[],
126
2
        );
127
2
128
2
        // Get output partitioning:
129
2
        // TODO: Optimize the cross join implementation to generate M * N
130
2
        //       partitions.
131
2
        let output_partitioning = adjust_right_output_partitioning(
132
2
            right.output_partitioning(),
133
2
            left.schema().fields.len(),
134
2
        );
135
2
136
2
        // Determine the execution mode:
137
2
        let mut mode = execution_mode_from_children([left, right]);
138
2
        if mode.is_unbounded() {
139
0
            // If any of the inputs is unbounded, cross join breaks the pipeline.
140
0
            mode = ExecutionMode::PipelineBreaking;
141
2
        }
142
143
2
        PlanProperties::new(eq_properties, output_partitioning, mode)
144
2
    }
145
}
146
147
/// Asynchronously collect the result of the left child
148
2
async fn load_left_input(
149
2
    left: Arc<dyn ExecutionPlan>,
150
2
    context: Arc<TaskContext>,
151
2
    metrics: BuildProbeJoinMetrics,
152
2
    reservation: MemoryReservation,
153
2
) -> Result<JoinLeftData> {
154
2
    // merge all left parts into a single stream
155
2
    let left_schema = left.schema();
156
2
    let merge = if left.output_partitioning().partition_count() != 1 {
157
0
        Arc::new(CoalescePartitionsExec::new(left))
158
    } else {
159
2
        left
160
    };
161
2
    let stream = merge.execute(0, context)
?0
;
162
163
    // Load all batches and count the rows
164
2
    let (
batches, _metrics, reservation1
) = stream
165
2
        .try_fold((Vec::new(), metrics, reservation), |mut acc, batch| async {
166
2
            let batch_size = batch.get_array_memory_size();
167
2
            // Reserve memory for incoming batch
168
2
            acc.2.try_grow(batch_size)
?1
;
169
            // Update metrics
170
1
            acc.1.build_mem_used.add(batch_size);
171
1
            acc.1.build_input_batches.add(1);
172
1
            acc.1.build_input_rows.add(batch.num_rows());
173
1
            // Push batch to output
174
1
            acc.0.push(batch);
175
1
            Ok(acc)
176
4
        }
)2
177
1
        .
await0
?;
178
179
1
    let merged_batch = concat_batches(&left_schema, &batches)
?0
;
180
181
1
    Ok((merged_batch, reservation))
182
2
}
183
184
impl DisplayAs for CrossJoinExec {
185
0
    fn fmt_as(
186
0
        &self,
187
0
        t: DisplayFormatType,
188
0
        f: &mut std::fmt::Formatter,
189
0
    ) -> std::fmt::Result {
190
0
        match t {
191
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
192
0
                write!(f, "CrossJoinExec")
193
0
            }
194
0
        }
195
0
    }
196
}
197
198
impl ExecutionPlan for CrossJoinExec {
199
0
    fn name(&self) -> &'static str {
200
0
        "CrossJoinExec"
201
0
    }
202
203
0
    fn as_any(&self) -> &dyn Any {
204
0
        self
205
0
    }
206
207
2
    fn properties(&self) -> &PlanProperties {
208
2
        &self.cache
209
2
    }
210
211
0
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
212
0
        vec![&self.left, &self.right]
213
0
    }
214
215
0
    fn metrics(&self) -> Option<MetricsSet> {
216
0
        Some(self.metrics.clone_inner())
217
0
    }
218
219
0
    fn with_new_children(
220
0
        self: Arc<Self>,
221
0
        children: Vec<Arc<dyn ExecutionPlan>>,
222
0
    ) -> Result<Arc<dyn ExecutionPlan>> {
223
0
        Ok(Arc::new(CrossJoinExec::new(
224
0
            Arc::clone(&children[0]),
225
0
            Arc::clone(&children[1]),
226
0
        )))
227
0
    }
228
229
0
    fn required_input_distribution(&self) -> Vec<Distribution> {
230
0
        vec![
231
0
            Distribution::SinglePartition,
232
0
            Distribution::UnspecifiedDistribution,
233
0
        ]
234
0
    }
235
236
2
    fn execute(
237
2
        &self,
238
2
        partition: usize,
239
2
        context: Arc<TaskContext>,
240
2
    ) -> Result<SendableRecordBatchStream> {
241
2
        let stream = self.right.execute(partition, Arc::clone(&context))
?0
;
242
243
2
        let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
244
2
245
2
        // Initialization of operator-level reservation
246
2
        let reservation =
247
2
            MemoryConsumer::new("CrossJoinExec").register(context.memory_pool());
248
2
249
2
        let left_fut = self.left_fut.once(|| {
250
2
            load_left_input(
251
2
                Arc::clone(&self.left),
252
2
                context,
253
2
                join_metrics.clone(),
254
2
                reservation,
255
2
            )
256
2
        });
257
2
258
2
        Ok(Box::pin(CrossJoinStream {
259
2
            schema: Arc::clone(&self.schema),
260
2
            left_fut,
261
2
            right: stream,
262
2
            left_index: 0,
263
2
            join_metrics,
264
2
            state: CrossJoinStreamState::WaitBuildSide,
265
2
            left_data: RecordBatch::new_empty(self.left().schema()),
266
2
        }))
267
2
    }
268
269
0
    fn statistics(&self) -> Result<Statistics> {
270
0
        Ok(stats_cartesian_product(
271
0
            self.left.statistics()?,
272
0
            self.right.statistics()?,
273
        ))
274
0
    }
275
}
276
277
/// [left/right]_col_count are required in case the column statistics are None
278
2
fn stats_cartesian_product(
279
2
    left_stats: Statistics,
280
2
    right_stats: Statistics,
281
2
) -> Statistics {
282
2
    let left_row_count = left_stats.num_rows;
283
2
    let right_row_count = right_stats.num_rows;
284
2
285
2
    // calculate global stats
286
2
    let num_rows = left_row_count.multiply(&right_row_count);
287
2
    // the result size is two times a*b because you have the columns of both left and right
288
2
    let total_byte_size = left_stats
289
2
        .total_byte_size
290
2
        .multiply(&right_stats.total_byte_size)
291
2
        .multiply(&Precision::Exact(2));
292
2
293
2
    let left_col_stats = left_stats.column_statistics;
294
2
    let right_col_stats = right_stats.column_statistics;
295
2
296
2
    // the null counts must be multiplied by the row counts of the other side (if defined)
297
2
    // Min, max and distinct_count on the other hand are invariants.
298
2
    let cross_join_stats = left_col_stats
299
2
        .into_iter()
300
4
        .map(|s| ColumnStatistics {
301
4
            null_count: s.null_count.multiply(&right_row_count),
302
4
            distinct_count: s.distinct_count,
303
4
            min_value: s.min_value,
304
4
            max_value: s.max_value,
305
4
        })
306
2
        .chain(right_col_stats.into_iter().map(|s| ColumnStatistics {
307
2
            null_count: s.null_count.multiply(&left_row_count),
308
2
            distinct_count: s.distinct_count,
309
2
            min_value: s.min_value,
310
2
            max_value: s.max_value,
311
2
        }))
312
2
        .collect();
313
2
314
2
    Statistics {
315
2
        num_rows,
316
2
        total_byte_size,
317
2
        column_statistics: cross_join_stats,
318
2
    }
319
2
}
320
321
/// A stream that issues [RecordBatch]es as they arrive from the right  of the join.
322
struct CrossJoinStream {
323
    /// Input schema
324
    schema: Arc<Schema>,
325
    /// Future for data from left side
326
    left_fut: OnceFut<JoinLeftData>,
327
    /// Right side stream
328
    right: SendableRecordBatchStream,
329
    /// Current value on the left
330
    left_index: usize,
331
    /// Join execution metrics
332
    join_metrics: BuildProbeJoinMetrics,
333
    /// State of the stream
334
    state: CrossJoinStreamState,
335
    /// Left data
336
    left_data: RecordBatch,
337
}
338
339
impl RecordBatchStream for CrossJoinStream {
340
0
    fn schema(&self) -> SchemaRef {
341
0
        Arc::clone(&self.schema)
342
0
    }
343
}
344
345
/// Represents states of CrossJoinStream
346
enum CrossJoinStreamState {
347
    WaitBuildSide,
348
    FetchProbeBatch,
349
    /// Holds the currently processed right side batch
350
    BuildBatches(RecordBatch),
351
}
352
353
impl CrossJoinStreamState {
354
    /// Tries to extract RecordBatch from CrossJoinStreamState enum.
355
    /// Returns an error if state is not BuildBatches state.
356
4
    fn try_as_record_batch(&mut self) -> Result<&RecordBatch> {
357
4
        match self {
358
4
            CrossJoinStreamState::BuildBatches(rb) => Ok(rb),
359
0
            _ => internal_err!("Expected RecordBatch in BuildBatches state"),
360
        }
361
4
    }
362
}
363
364
3
fn build_batch(
365
3
    left_index: usize,
366
3
    batch: &RecordBatch,
367
3
    left_data: &RecordBatch,
368
3
    schema: &Schema,
369
3
) -> Result<RecordBatch> {
370
    // Repeat value on the left n times
371
3
    let arrays = left_data
372
3
        .columns()
373
3
        .iter()
374
9
        .map(|arr| {
375
9
            let scalar = ScalarValue::try_from_array(arr, left_index)
?0
;
376
9
            scalar.to_array_of_size(batch.num_rows())
377
9
        })
378
3
        .collect::<Result<Vec<_>>>()
?0
;
379
380
3
    RecordBatch::try_new_with_options(
381
3
        Arc::new(schema.clone()),
382
3
        arrays
383
3
            .iter()
384
3
            .chain(batch.columns().iter())
385
3
            .cloned()
386
3
            .collect(),
387
3
        &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())),
388
3
    )
389
3
    .map_err(Into::into)
390
3
}
391
392
#[async_trait]
393
impl Stream for CrossJoinStream {
394
    type Item = Result<RecordBatch>;
395
396
5
    fn poll_next(
397
5
        mut self: std::pin::Pin<&mut Self>,
398
5
        cx: &mut std::task::Context<'_>,
399
5
    ) -> std::task::Poll<Option<Self::Item>> {
400
5
        self.poll_next_impl(cx)
401
5
    }
402
}
403
404
impl CrossJoinStream {
405
    /// Separate implementation function that unpins the [`CrossJoinStream`] so
406
    /// that partial borrows work correctly
407
5
    fn poll_next_impl(
408
5
        &mut self,
409
5
        cx: &mut std::task::Context<'_>,
410
5
    ) -> std::task::Poll<Option<Result<RecordBatch>>> {
411
        loop {
412
8
            return match self.state {
413
                CrossJoinStreamState::WaitBuildSide => {
414
2
                    
handle_state!1
(
ready!0
(self.collect_build_side(cx)))
415
                }
416
                CrossJoinStreamState::FetchProbeBatch => {
417
2
                    
handle_state!0
(
ready!0
(self.fetch_probe_batch(cx)))
418
                }
419
                CrossJoinStreamState::BuildBatches(_) => {
420
4
                    
handle_state!0
(self.build_batches())
421
                }
422
            };
423
        }
424
5
    }
425
426
    /// Collects build (left) side of the join into the state. In case of an empty build batch,
427
    /// the execution terminates. Otherwise, the state is updated to fetch probe (right) batch.
428
2
    fn collect_build_side(
429
2
        &mut self,
430
2
        cx: &mut std::task::Context<'_>,
431
2
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
432
2
        let build_timer = self.join_metrics.build_time.timer();
433
2
        let (
left_data1
, _) = match
ready!0
(self.left_fut.get(cx)) {
434
1
            Ok(left_data) => left_data,
435
1
            Err(e) => return Poll::Ready(Err(e)),
436
        };
437
1
        build_timer.done();
438
439
1
        let result = if left_data.num_rows() == 0 {
440
0
            StatefulStreamResult::Ready(None)
441
        } else {
442
1
            self.left_data = left_data.clone();
443
1
            self.state = CrossJoinStreamState::FetchProbeBatch;
444
1
            StatefulStreamResult::Continue
445
        };
446
1
        Poll::Ready(Ok(result))
447
2
    }
448
449
    /// Fetches the probe (right) batch, updates the metrics, and save the batch in the state.
450
    /// Then, the state is updated to build result batches.
451
2
    fn fetch_probe_batch(
452
2
        &mut self,
453
2
        cx: &mut std::task::Context<'_>,
454
2
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
455
2
        self.left_index = 0;
456
2
        let 
right_data1
= match
ready!0
(self.right.poll_next_unpin(cx)) {
457
1
            Some(Ok(right_data)) => right_data,
458
0
            Some(Err(e)) => return Poll::Ready(Err(e)),
459
1
            None => return Poll::Ready(Ok(StatefulStreamResult::Ready(None))),
460
        };
461
1
        self.join_metrics.input_batches.add(1);
462
1
        self.join_metrics.input_rows.add(right_data.num_rows());
463
1
464
1
        self.state = CrossJoinStreamState::BuildBatches(right_data);
465
1
        Poll::Ready(Ok(StatefulStreamResult::Continue))
466
2
    }
467
468
    /// Joins the the indexed row of left data with the current probe batch.
469
    /// If all the results are produced, the state is set to fetch new probe batch.
470
4
    fn build_batches(&mut self) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
471
4
        let right_batch = self.state.try_as_record_batch()
?0
;
472
4
        if self.left_index < self.left_data.num_rows() {
473
3
            let join_timer = self.join_metrics.join_time.timer();
474
3
            let result =
475
3
                build_batch(self.left_index, right_batch, &self.left_data, &self.schema);
476
3
            join_timer.done();
477
478
3
            if let Ok(ref batch) = result {
479
3
                self.join_metrics.output_batches.add(1);
480
3
                self.join_metrics.output_rows.add(batch.num_rows());
481
3
            }
0
482
3
            self.left_index += 1;
483
3
            result.map(|r| StatefulStreamResult::Ready(Some(r)))
484
        } else {
485
1
            self.state = CrossJoinStreamState::FetchProbeBatch;
486
1
            Ok(StatefulStreamResult::Continue)
487
        }
488
4
    }
489
}
490
491
#[cfg(test)]
492
mod tests {
493
    use super::*;
494
    use crate::common;
495
    use crate::test::build_table_scan_i32;
496
497
    use datafusion_common::{assert_batches_sorted_eq, assert_contains};
498
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
499
500
2
    async fn join_collect(
501
2
        left: Arc<dyn ExecutionPlan>,
502
2
        right: Arc<dyn ExecutionPlan>,
503
2
        context: Arc<TaskContext>,
504
2
    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
505
2
        let join = CrossJoinExec::new(left, right);
506
2
        let columns_header = columns(&join.schema());
507
508
2
        let stream = join.execute(0, context)
?0
;
509
2
        let 
batches1
= common::collect(stream).
await0
?1
;
510
511
1
        Ok((columns_header, batches))
512
2
    }
513
514
    #[tokio::test]
515
1
    async fn test_stats_cartesian_product() {
516
1
        let left_row_count = 11;
517
1
        let left_bytes = 23;
518
1
        let right_row_count = 7;
519
1
        let right_bytes = 27;
520
1
521
1
        let left = Statistics {
522
1
            num_rows: Precision::Exact(left_row_count),
523
1
            total_byte_size: Precision::Exact(left_bytes),
524
1
            column_statistics: vec![
525
1
                ColumnStatistics {
526
1
                    distinct_count: Precision::Exact(5),
527
1
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
528
1
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
529
1
                    null_count: Precision::Exact(0),
530
1
                },
531
1
                ColumnStatistics {
532
1
                    distinct_count: Precision::Exact(1),
533
1
                    max_value: Precision::Exact(ScalarValue::from("x")),
534
1
                    min_value: Precision::Exact(ScalarValue::from("a")),
535
1
                    null_count: Precision::Exact(3),
536
1
                },
537
1
            ],
538
1
        };
539
1
540
1
        let right = Statistics {
541
1
            num_rows: Precision::Exact(right_row_count),
542
1
            total_byte_size: Precision::Exact(right_bytes),
543
1
            column_statistics: vec![ColumnStatistics {
544
1
                distinct_count: Precision::Exact(3),
545
1
                max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
546
1
                min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
547
1
                null_count: Precision::Exact(2),
548
1
            }],
549
1
        };
550
1
551
1
        let result = stats_cartesian_product(left, right);
552
1
553
1
        let expected = Statistics {
554
1
            num_rows: Precision::Exact(left_row_count * right_row_count),
555
1
            total_byte_size: Precision::Exact(2 * left_bytes * right_bytes),
556
1
            column_statistics: vec![
557
1
                ColumnStatistics {
558
1
                    distinct_count: Precision::Exact(5),
559
1
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
560
1
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
561
1
                    null_count: Precision::Exact(0),
562
1
                },
563
1
                ColumnStatistics {
564
1
                    distinct_count: Precision::Exact(1),
565
1
                    max_value: Precision::Exact(ScalarValue::from("x")),
566
1
                    min_value: Precision::Exact(ScalarValue::from("a")),
567
1
                    null_count: Precision::Exact(3 * right_row_count),
568
1
                },
569
1
                ColumnStatistics {
570
1
                    distinct_count: Precision::Exact(3),
571
1
                    max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
572
1
                    min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
573
1
                    null_count: Precision::Exact(2 * left_row_count),
574
1
                },
575
1
            ],
576
1
        };
577
1
578
1
        assert_eq!(result, expected);
579
1
    }
580
581
    #[tokio::test]
582
1
    async fn test_stats_cartesian_product_with_unknown_size() {
583
1
        let left_row_count = 11;
584
1
585
1
        let left = Statistics {
586
1
            num_rows: Precision::Exact(left_row_count),
587
1
            total_byte_size: Precision::Exact(23),
588
1
            column_statistics: vec![
589
1
                ColumnStatistics {
590
1
                    distinct_count: Precision::Exact(5),
591
1
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
592
1
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
593
1
                    null_count: Precision::Exact(0),
594
1
                },
595
1
                ColumnStatistics {
596
1
                    distinct_count: Precision::Exact(1),
597
1
                    max_value: Precision::Exact(ScalarValue::from("x")),
598
1
                    min_value: Precision::Exact(ScalarValue::from("a")),
599
1
                    null_count: Precision::Exact(3),
600
1
                },
601
1
            ],
602
1
        };
603
1
604
1
        let right = Statistics {
605
1
            num_rows: Precision::Absent,
606
1
            total_byte_size: Precision::Absent,
607
1
            column_statistics: vec![ColumnStatistics {
608
1
                distinct_count: Precision::Exact(3),
609
1
                max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
610
1
                min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
611
1
                null_count: Precision::Exact(2),
612
1
            }],
613
1
        };
614
1
615
1
        let result = stats_cartesian_product(left, right);
616
1
617
1
        let expected = Statistics {
618
1
            num_rows: Precision::Absent,
619
1
            total_byte_size: Precision::Absent,
620
1
            column_statistics: vec![
621
1
                ColumnStatistics {
622
1
                    distinct_count: Precision::Exact(5),
623
1
                    max_value: Precision::Exact(ScalarValue::Int64(Some(21))),
624
1
                    min_value: Precision::Exact(ScalarValue::Int64(Some(-4))),
625
1
                    null_count: Precision::Absent, // we don't know the row count on the right
626
1
                },
627
1
                ColumnStatistics {
628
1
                    distinct_count: Precision::Exact(1),
629
1
                    max_value: Precision::Exact(ScalarValue::from("x")),
630
1
                    min_value: Precision::Exact(ScalarValue::from("a")),
631
1
                    null_count: Precision::Absent, // we don't know the row count on the right
632
1
                },
633
1
                ColumnStatistics {
634
1
                    distinct_count: Precision::Exact(3),
635
1
                    max_value: Precision::Exact(ScalarValue::Int64(Some(12))),
636
1
                    min_value: Precision::Exact(ScalarValue::Int64(Some(0))),
637
1
                    null_count: Precision::Exact(2 * left_row_count),
638
1
                },
639
1
            ],
640
1
        };
641
1
642
1
        assert_eq!(result, expected);
643
1
    }
644
645
    #[tokio::test]
646
1
    async fn test_join() -> Result<()> {
647
1
        let task_ctx = Arc::new(TaskContext::default());
648
1
649
1
        let left = build_table_scan_i32(
650
1
            ("a1", &vec![1, 2, 3]),
651
1
            ("b1", &vec![4, 5, 6]),
652
1
            ("c1", &vec![7, 8, 9]),
653
1
        );
654
1
        let right = build_table_scan_i32(
655
1
            ("a2", &vec![10, 11]),
656
1
            ("b2", &vec![12, 13]),
657
1
            ("c2", &vec![14, 15]),
658
1
        );
659
1
660
1
        let (columns, batches) = join_collect(left, right, task_ctx).
await0
?0
;
661
1
662
1
        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
663
1
        let expected = [
664
1
            "+----+----+----+----+----+----+",
665
1
            "| a1 | b1 | c1 | a2 | b2 | c2 |",
666
1
            "+----+----+----+----+----+----+",
667
1
            "| 1  | 4  | 7  | 10 | 12 | 14 |",
668
1
            "| 1  | 4  | 7  | 11 | 13 | 15 |",
669
1
            "| 2  | 5  | 8  | 10 | 12 | 14 |",
670
1
            "| 2  | 5  | 8  | 11 | 13 | 15 |",
671
1
            "| 3  | 6  | 9  | 10 | 12 | 14 |",
672
1
            "| 3  | 6  | 9  | 11 | 13 | 15 |",
673
1
            "+----+----+----+----+----+----+",
674
1
        ];
675
1
676
1
        assert_batches_sorted_eq!(expected, &batches);
677
1
678
1
        Ok(())
679
1
    }
680
681
    #[tokio::test]
682
1
    async fn test_overallocation() -> Result<()> {
683
1
        let runtime = RuntimeEnvBuilder::new()
684
1
            .with_memory_limit(100, 1.0)
685
1
            .build_arc()
?0
;
686
1
        let task_ctx = TaskContext::default().with_runtime(runtime);
687
1
        let task_ctx = Arc::new(task_ctx);
688
1
689
1
        let left = build_table_scan_i32(
690
1
            ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
691
1
            ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
692
1
            ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
693
1
        );
694
1
        let right = build_table_scan_i32(
695
1
            ("a2", &vec![10, 11]),
696
1
            ("b2", &vec![12, 13]),
697
1
            ("c2", &vec![14, 15]),
698
1
        );
699
1
700
1
        let err = join_collect(left, right, task_ctx).
await0
.unwrap_err();
701
1
702
1
        assert_contains!(
703
1
            err.to_string(),
704
1
            "External error: Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: CrossJoinExec"
705
1
        );
706
1
707
1
        Ok(())
708
1
    }
709
710
    /// Returns the column names on the schema
711
2
    fn columns(schema: &Schema) -> Vec<String> {
712
12
        schema.fields().iter().map(|f| f.name().clone()).collect()
713
2
    }
714
}