Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/joins/sort_merge_join.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Defines the Sort-Merge join execution plan.
19
//! A Sort-Merge join plan consumes two sorted children plan and produces
20
//! joined output by given join type and other options.
21
//! Sort-Merge join feature is currently experimental.
22
23
use std::any::Any;
24
use std::cmp::Ordering;
25
use std::collections::{HashMap, VecDeque};
26
use std::fmt::Formatter;
27
use std::fs::File;
28
use std::io::BufReader;
29
use std::mem;
30
use std::ops::Range;
31
use std::pin::Pin;
32
use std::sync::Arc;
33
use std::task::{Context, Poll};
34
35
use arrow::array::*;
36
use arrow::compute::{self, concat_batches, take, SortOptions};
37
use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
38
use arrow::error::ArrowError;
39
use arrow::ipc::reader::FileReader;
40
use arrow_array::types::UInt64Type;
41
use futures::{Stream, StreamExt};
42
use hashbrown::HashSet;
43
44
use datafusion_common::{
45
    exec_err, internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType,
46
    Result,
47
};
48
use datafusion_execution::disk_manager::RefCountedTempFile;
49
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
50
use datafusion_execution::runtime_env::RuntimeEnv;
51
use datafusion_execution::TaskContext;
52
use datafusion_physical_expr::equivalence::join_equivalence_properties;
53
use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement};
54
use datafusion_physical_expr_common::sort_expr::LexRequirement;
55
56
use crate::expressions::PhysicalSortExpr;
57
use crate::joins::utils::{
58
    build_join_schema, check_join_is_valid, estimate_join_statistics,
59
    symmetric_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef,
60
};
61
use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
62
use crate::spill::spill_record_batches;
63
use crate::{
64
    execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution,
65
    ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties,
66
    RecordBatchStream, SendableRecordBatchStream, Statistics,
67
};
68
69
/// join execution plan executes partitions in parallel and combines them into a set of
70
/// partitions.
71
#[derive(Debug)]
72
pub struct SortMergeJoinExec {
73
    /// Left sorted joining execution plan
74
    pub left: Arc<dyn ExecutionPlan>,
75
    /// Right sorting joining execution plan
76
    pub right: Arc<dyn ExecutionPlan>,
77
    /// Set of common columns used to join on
78
    pub on: JoinOn,
79
    /// Filters which are applied while finding matching rows
80
    pub filter: Option<JoinFilter>,
81
    /// How the join is performed
82
    pub join_type: JoinType,
83
    /// The schema once the join is applied
84
    schema: SchemaRef,
85
    /// Execution metrics
86
    metrics: ExecutionPlanMetricsSet,
87
    /// The left SortExpr
88
    left_sort_exprs: Vec<PhysicalSortExpr>,
89
    /// The right SortExpr
90
    right_sort_exprs: Vec<PhysicalSortExpr>,
91
    /// Sort options of join columns used in sorting left and right execution plans
92
    pub sort_options: Vec<SortOptions>,
93
    /// If null_equals_null is true, null == null else null != null
94
    pub null_equals_null: bool,
95
    /// Cache holding plan properties like equivalences, output partitioning etc.
96
    cache: PlanProperties,
97
}
98
99
impl SortMergeJoinExec {
100
    /// Tries to create a new [SortMergeJoinExec].
101
    /// The inputs are sorted using `sort_options` are applied to the columns in the `on`
102
    /// # Error
103
    /// This function errors when it is not possible to join the left and right sides on keys `on`.
104
79
    pub fn try_new(
105
79
        left: Arc<dyn ExecutionPlan>,
106
79
        right: Arc<dyn ExecutionPlan>,
107
79
        on: JoinOn,
108
79
        filter: Option<JoinFilter>,
109
79
        join_type: JoinType,
110
79
        sort_options: Vec<SortOptions>,
111
79
        null_equals_null: bool,
112
79
    ) -> Result<Self> {
113
79
        let left_schema = left.schema();
114
79
        let right_schema = right.schema();
115
79
116
79
        if join_type == JoinType::RightSemi {
117
0
            return not_impl_err!(
118
0
                "SortMergeJoinExec does not support JoinType::RightSemi"
119
0
            );
120
79
        }
121
79
122
79
        check_join_is_valid(&left_schema, &right_schema, &on)
?0
;
123
79
        if sort_options.len() != on.len() {
124
0
            return plan_err!(
125
0
                "Expected number of sort options: {}, actual: {}",
126
0
                on.len(),
127
0
                sort_options.len()
128
0
            );
129
79
        }
130
79
131
79
        let (left_sort_exprs, right_sort_exprs): (Vec<_>, Vec<_>) = on
132
79
            .iter()
133
79
            .zip(sort_options.iter())
134
84
            .map(|((l, r), sort_op)| {
135
84
                let left = PhysicalSortExpr {
136
84
                    expr: Arc::clone(l),
137
84
                    options: *sort_op,
138
84
                };
139
84
                let right = PhysicalSortExpr {
140
84
                    expr: Arc::clone(r),
141
84
                    options: *sort_op,
142
84
                };
143
84
                (left, right)
144
84
            })
145
79
            .unzip();
146
79
147
79
        let schema =
148
79
            Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0);
149
79
        let cache =
150
79
            Self::compute_properties(&left, &right, Arc::clone(&schema), join_type, &on);
151
79
        Ok(Self {
152
79
            left,
153
79
            right,
154
79
            on,
155
79
            filter,
156
79
            join_type,
157
79
            schema,
158
79
            metrics: ExecutionPlanMetricsSet::new(),
159
79
            left_sort_exprs,
160
79
            right_sort_exprs,
161
79
            sort_options,
162
79
            null_equals_null,
163
79
            cache,
164
79
        })
165
79
    }
166
167
    /// Get probe side (e.g streaming side) information for this sort merge join.
168
    /// In current implementation, probe side is determined according to join type.
169
158
    pub fn probe_side(join_type: &JoinType) -> JoinSide {
170
158
        // When output schema contains only the right side, probe side is right.
171
158
        // Otherwise probe side is the left side.
172
158
        match join_type {
173
            JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => {
174
26
                JoinSide::Right
175
            }
176
            JoinType::Inner
177
            | JoinType::Left
178
            | JoinType::Full
179
            | JoinType::LeftAnti
180
132
            | JoinType::LeftSemi => JoinSide::Left,
181
        }
182
158
    }
183
184
    /// Calculate order preservation flags for this sort merge join.
185
79
    fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
186
79
        match join_type {
187
19
            JoinType::Inner => vec![true, false],
188
35
            JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => vec![true, false],
189
            JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => {
190
13
                vec![false, true]
191
            }
192
12
            _ => vec![false, false],
193
        }
194
79
    }
195
196
    /// Set of common columns used to join on
197
0
    pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
198
0
        &self.on
199
0
    }
200
201
0
    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
202
0
        &self.right
203
0
    }
204
205
0
    pub fn join_type(&self) -> JoinType {
206
0
        self.join_type
207
0
    }
208
209
0
    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
210
0
        &self.left
211
0
    }
212
213
    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
214
79
    fn compute_properties(
215
79
        left: &Arc<dyn ExecutionPlan>,
216
79
        right: &Arc<dyn ExecutionPlan>,
217
79
        schema: SchemaRef,
218
79
        join_type: JoinType,
219
79
        join_on: JoinOnRef,
220
79
    ) -> PlanProperties {
221
79
        // Calculate equivalence properties:
222
79
        let eq_properties = join_equivalence_properties(
223
79
            left.equivalence_properties().clone(),
224
79
            right.equivalence_properties().clone(),
225
79
            &join_type,
226
79
            schema,
227
79
            &Self::maintains_input_order(join_type),
228
79
            Some(Self::probe_side(&join_type)),
229
79
            join_on,
230
79
        );
231
79
232
79
        let output_partitioning =
233
79
            symmetric_join_output_partitioning(left, right, &join_type);
234
79
235
79
        // Determine execution mode:
236
79
        let mode = execution_mode_from_children([left, right]);
237
79
238
79
        PlanProperties::new(eq_properties, output_partitioning, mode)
239
79
    }
240
}
241
242
impl DisplayAs for SortMergeJoinExec {
243
0
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
244
0
        match t {
245
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
246
0
                let on = self
247
0
                    .on
248
0
                    .iter()
249
0
                    .map(|(c1, c2)| format!("({}, {})", c1, c2))
250
0
                    .collect::<Vec<String>>()
251
0
                    .join(", ");
252
0
                write!(
253
0
                    f,
254
0
                    "SortMergeJoin: join_type={:?}, on=[{}]{}",
255
0
                    self.join_type,
256
0
                    on,
257
0
                    self.filter.as_ref().map_or("".to_string(), |f| format!(
258
0
                        ", filter={}",
259
0
                        f.expression()
260
0
                    ))
261
0
                )
262
0
            }
263
0
        }
264
0
    }
265
}
266
267
impl ExecutionPlan for SortMergeJoinExec {
268
0
    fn name(&self) -> &'static str {
269
0
        "SortMergeJoinExec"
270
0
    }
271
272
0
    fn as_any(&self) -> &dyn Any {
273
0
        self
274
0
    }
275
276
19
    fn properties(&self) -> &PlanProperties {
277
19
        &self.cache
278
19
    }
279
280
0
    fn required_input_distribution(&self) -> Vec<Distribution> {
281
0
        let (left_expr, right_expr) = self
282
0
            .on
283
0
            .iter()
284
0
            .map(|(l, r)| (Arc::clone(l), Arc::clone(r)))
285
0
            .unzip();
286
0
        vec![
287
0
            Distribution::HashPartitioned(left_expr),
288
0
            Distribution::HashPartitioned(right_expr),
289
0
        ]
290
0
    }
291
292
0
    fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> {
293
0
        vec![
294
0
            Some(PhysicalSortRequirement::from_sort_exprs(
295
0
                &self.left_sort_exprs,
296
0
            )),
297
0
            Some(PhysicalSortRequirement::from_sort_exprs(
298
0
                &self.right_sort_exprs,
299
0
            )),
300
0
        ]
301
0
    }
302
303
0
    fn maintains_input_order(&self) -> Vec<bool> {
304
0
        Self::maintains_input_order(self.join_type)
305
0
    }
306
307
0
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
308
0
        vec![&self.left, &self.right]
309
0
    }
310
311
0
    fn with_new_children(
312
0
        self: Arc<Self>,
313
0
        children: Vec<Arc<dyn ExecutionPlan>>,
314
0
    ) -> Result<Arc<dyn ExecutionPlan>> {
315
0
        match &children[..] {
316
0
            [left, right] => Ok(Arc::new(SortMergeJoinExec::try_new(
317
0
                Arc::clone(left),
318
0
                Arc::clone(right),
319
0
                self.on.clone(),
320
0
                self.filter.clone(),
321
0
                self.join_type,
322
0
                self.sort_options.clone(),
323
0
                self.null_equals_null,
324
0
            )?)),
325
0
            _ => internal_err!("SortMergeJoin wrong number of children"),
326
        }
327
0
    }
328
329
79
    fn execute(
330
79
        &self,
331
79
        partition: usize,
332
79
        context: Arc<TaskContext>,
333
79
    ) -> Result<SendableRecordBatchStream> {
334
79
        let left_partitions = self.left.output_partitioning().partition_count();
335
79
        let right_partitions = self.right.output_partitioning().partition_count();
336
79
        if left_partitions != right_partitions {
337
0
            return internal_err!(
338
0
                "Invalid SortMergeJoinExec, partition count mismatch {left_partitions}!={right_partitions},\
339
0
                 consider using RepartitionExec"
340
0
            );
341
79
        }
342
79
        let (on_left, on_right) = self.on.iter().cloned().unzip();
343
79
        let (streamed, buffered, on_streamed, on_buffered) =
344
79
            if SortMergeJoinExec::probe_side(&self.join_type) == JoinSide::Left {
345
66
                (
346
66
                    Arc::clone(&self.left),
347
66
                    Arc::clone(&self.right),
348
66
                    on_left,
349
66
                    on_right,
350
66
                )
351
            } else {
352
13
                (
353
13
                    Arc::clone(&self.right),
354
13
                    Arc::clone(&self.left),
355
13
                    on_right,
356
13
                    on_left,
357
13
                )
358
            };
359
360
        // execute children plans
361
79
        let streamed = streamed.execute(partition, Arc::clone(&context))
?0
;
362
79
        let buffered = buffered.execute(partition, Arc::clone(&context))
?0
;
363
364
        // create output buffer
365
79
        let batch_size = context.session_config().batch_size();
366
79
367
79
        // create memory reservation
368
79
        let reservation = MemoryConsumer::new(format!("SMJStream[{partition}]"))
369
79
            .register(context.memory_pool());
370
79
371
79
        // create join stream
372
79
        Ok(Box::pin(SMJStream::try_new(
373
79
            Arc::clone(&self.schema),
374
79
            self.sort_options.clone(),
375
79
            self.null_equals_null,
376
79
            streamed,
377
79
            buffered,
378
79
            on_streamed,
379
79
            on_buffered,
380
79
            self.filter.clone(),
381
79
            self.join_type,
382
79
            batch_size,
383
79
            SortMergeJoinMetrics::new(partition, &self.metrics),
384
79
            reservation,
385
79
            context.runtime_env(),
386
79
        )
?0
))
387
79
    }
388
389
240
    fn metrics(&self) -> Option<MetricsSet> {
390
240
        Some(self.metrics.clone_inner())
391
240
    }
392
393
0
    fn statistics(&self) -> Result<Statistics> {
394
0
        // TODO stats: it is not possible in general to know the output size of joins
395
0
        // There are some special cases though, for example:
396
0
        // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)`
397
0
        estimate_join_statistics(
398
0
            Arc::clone(&self.left),
399
0
            Arc::clone(&self.right),
400
0
            self.on.clone(),
401
0
            &self.join_type,
402
0
            &self.schema,
403
0
        )
404
0
    }
405
}
406
407
/// Metrics for SortMergeJoinExec
408
#[allow(dead_code)]
409
struct SortMergeJoinMetrics {
410
    /// Total time for joining probe-side batches to the build-side batches
411
    join_time: metrics::Time,
412
    /// Number of batches consumed by this operator
413
    input_batches: metrics::Count,
414
    /// Number of rows consumed by this operator
415
    input_rows: metrics::Count,
416
    /// Number of batches produced by this operator
417
    output_batches: metrics::Count,
418
    /// Number of rows produced by this operator
419
    output_rows: metrics::Count,
420
    /// Peak memory used for buffered data.
421
    /// Calculated as sum of peak memory values across partitions
422
    peak_mem_used: metrics::Gauge,
423
    /// count of spills during the execution of the operator
424
    spill_count: Count,
425
    /// total spilled bytes during the execution of the operator
426
    spilled_bytes: Count,
427
    /// total spilled rows during the execution of the operator
428
    spilled_rows: Count,
429
}
430
431
impl SortMergeJoinMetrics {
432
    #[allow(dead_code)]
433
79
    pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self {
434
79
        let join_time = MetricBuilder::new(metrics).subset_time("join_time", partition);
435
79
        let input_batches =
436
79
            MetricBuilder::new(metrics).counter("input_batches", partition);
437
79
        let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition);
438
79
        let output_batches =
439
79
            MetricBuilder::new(metrics).counter("output_batches", partition);
440
79
        let output_rows = MetricBuilder::new(metrics).output_rows(partition);
441
79
        let peak_mem_used = MetricBuilder::new(metrics).gauge("peak_mem_used", partition);
442
79
        let spill_count = MetricBuilder::new(metrics).spill_count(partition);
443
79
        let spilled_bytes = MetricBuilder::new(metrics).spilled_bytes(partition);
444
79
        let spilled_rows = MetricBuilder::new(metrics).spilled_rows(partition);
445
79
446
79
        Self {
447
79
            join_time,
448
79
            input_batches,
449
79
            input_rows,
450
79
            output_batches,
451
79
            output_rows,
452
79
            peak_mem_used,
453
79
            spill_count,
454
79
            spilled_bytes,
455
79
            spilled_rows,
456
79
        }
457
79
    }
458
}
459
460
/// State of SMJ stream
461
#[derive(Debug, PartialEq, Eq)]
462
enum SMJState {
463
    /// Init joining with a new streamed row or a new buffered batches
464
    Init,
465
    /// Polling one streamed row or one buffered batch, or both
466
    Polling,
467
    /// Joining polled data and making output
468
    JoinOutput,
469
    /// No more output
470
    Exhausted,
471
}
472
473
/// State of streamed data stream
474
#[derive(Debug, PartialEq, Eq)]
475
enum StreamedState {
476
    /// Init polling
477
    Init,
478
    /// Polling one streamed row
479
    Polling,
480
    /// Ready to produce one streamed row
481
    Ready,
482
    /// No more streamed row
483
    Exhausted,
484
}
485
486
/// State of buffered data stream
487
#[derive(Debug, PartialEq, Eq)]
488
enum BufferedState {
489
    /// Init polling
490
    Init,
491
    /// Polling first row in the next batch
492
    PollingFirst,
493
    /// Polling rest rows in the next batch
494
    PollingRest,
495
    /// Ready to produce one batch
496
    Ready,
497
    /// No more buffered batches
498
    Exhausted,
499
}
500
501
/// Represents a chunk of joined data from streamed and buffered side
502
struct StreamedJoinedChunk {
503
    /// Index of batch in buffered_data
504
    buffered_batch_idx: Option<usize>,
505
    /// Array builder for streamed indices
506
    streamed_indices: UInt64Builder,
507
    /// Array builder for buffered indices
508
    /// This could contain nulls if the join is null-joined
509
    buffered_indices: UInt64Builder,
510
}
511
512
struct StreamedBatch {
513
    /// The streamed record batch
514
    pub batch: RecordBatch,
515
    /// The index of row in the streamed batch to compare with buffered batches
516
    pub idx: usize,
517
    /// The join key arrays of streamed batch which are used to compare with buffered batches
518
    /// and to produce output. They are produced by evaluating `on` expressions.
519
    pub join_arrays: Vec<ArrayRef>,
520
    /// Chunks of indices from buffered side (may be nulls) joined to streamed
521
    pub output_indices: Vec<StreamedJoinedChunk>,
522
    /// Index of currently scanned batch from buffered data
523
    pub buffered_batch_idx: Option<usize>,
524
    /// Indices that found a match for the given join filter
525
    /// Used for semi joins to keep track the streaming index which got a join filter match
526
    /// and already emitted to the output.
527
    pub join_filter_matched_idxs: HashSet<u64>,
528
}
529
530
impl StreamedBatch {
531
130
    fn new(batch: RecordBatch, on_column: &[Arc<dyn PhysicalExpr>]) -> Self {
532
130
        let join_arrays = join_arrays(&batch, on_column);
533
130
        StreamedBatch {
534
130
            batch,
535
130
            idx: 0,
536
130
            join_arrays,
537
130
            output_indices: vec![],
538
130
            buffered_batch_idx: None,
539
130
            join_filter_matched_idxs: HashSet::new(),
540
130
        }
541
130
    }
542
543
79
    fn new_empty(schema: SchemaRef) -> Self {
544
79
        StreamedBatch {
545
79
            batch: RecordBatch::new_empty(schema),
546
79
            idx: 0,
547
79
            join_arrays: vec![],
548
79
            output_indices: vec![],
549
79
            buffered_batch_idx: None,
550
79
            join_filter_matched_idxs: HashSet::new(),
551
79
        }
552
79
    }
553
554
    /// Appends new pair consisting of current streamed index and `buffered_idx`
555
    /// index of buffered batch with `buffered_batch_idx` index.
556
693
    fn append_output_pair(
557
693
        &mut self,
558
693
        buffered_batch_idx: Option<usize>,
559
693
        buffered_idx: Option<usize>,
560
693
    ) {
561
693
        // If no current chunk exists or current chunk is not for current buffered batch,
562
693
        // create a new chunk
563
693
        if self.output_indices.is_empty() || 
self.buffered_batch_idx != buffered_batch_idx312
564
504
        {
565
504
            self.output_indices.push(StreamedJoinedChunk {
566
504
                buffered_batch_idx,
567
504
                streamed_indices: UInt64Builder::with_capacity(1),
568
504
                buffered_indices: UInt64Builder::with_capacity(1),
569
504
            });
570
504
            self.buffered_batch_idx = buffered_batch_idx;
571
504
        }
;189
572
693
        let current_chunk = self.output_indices.last_mut().unwrap();
573
693
574
693
        // Append index of streamed batch and index of buffered batch into current chunk
575
693
        current_chunk.streamed_indices.append_value(self.idx as u64);
576
693
        if let Some(
idx600
) = buffered_idx {
577
600
            current_chunk.buffered_indices.append_value(idx as u64);
578
600
        } else {
579
93
            current_chunk.buffered_indices.append_null();
580
93
        }
581
693
    }
582
}
583
584
/// A buffered batch that contains contiguous rows with same join key
585
#[derive(Debug)]
586
struct BufferedBatch {
587
    /// The buffered record batch
588
    /// None if the batch spilled to disk th
589
    pub batch: Option<RecordBatch>,
590
    /// The range in which the rows share the same join key
591
    pub range: Range<usize>,
592
    /// Array refs of the join key
593
    pub join_arrays: Vec<ArrayRef>,
594
    /// Buffered joined index (null joining buffered)
595
    pub null_joined: Vec<usize>,
596
    /// Size estimation used for reserving / releasing memory
597
    pub size_estimation: usize,
598
    /// The indices of buffered batch that failed the join filter.
599
    /// This is a map between buffered row index and a boolean value indicating whether all joined row
600
    /// of the buffered row failed the join filter.
601
    /// When dequeuing the buffered batch, we need to produce null joined rows for these indices.
602
    pub join_filter_failed_map: HashMap<u64, bool>,
603
    /// Current buffered batch number of rows. Equal to batch.num_rows()
604
    /// but if batch is spilled to disk this property is preferable
605
    /// and less expensive
606
    pub num_rows: usize,
607
    /// An optional temp spill file name on the disk if the batch spilled
608
    /// None by default
609
    /// Some(fileName) if the batch spilled to the disk
610
    pub spill_file: Option<RefCountedTempFile>,
611
}
612
613
impl BufferedBatch {
614
130
    fn new(
615
130
        batch: RecordBatch,
616
130
        range: Range<usize>,
617
130
        on_column: &[PhysicalExprRef],
618
130
    ) -> Self {
619
130
        let join_arrays = join_arrays(&batch, on_column);
620
130
621
130
        // Estimation is calculated as
622
130
        //   inner batch size
623
130
        // + join keys size
624
130
        // + worst case null_joined (as vector capacity * element size)
625
130
        // + Range size
626
130
        // + size of this estimation
627
130
        let size_estimation = batch.get_array_memory_size()
628
130
            + join_arrays
629
130
                .iter()
630
135
                .map(|arr| arr.get_array_memory_size())
631
130
                .sum::<usize>()
632
130
            + batch.num_rows().next_power_of_two() * mem::size_of::<usize>()
633
130
            + mem::size_of::<Range<usize>>()
634
130
            + mem::size_of::<usize>();
635
130
636
130
        let num_rows = batch.num_rows();
637
130
        BufferedBatch {
638
130
            batch: Some(batch),
639
130
            range,
640
130
            join_arrays,
641
130
            null_joined: vec![],
642
130
            size_estimation,
643
130
            join_filter_failed_map: HashMap::new(),
644
130
            num_rows,
645
130
            spill_file: None,
646
130
        }
647
130
    }
648
}
649
650
/// Sort-merge join stream that consumes streamed and buffered data stream
651
/// and produces joined output
652
struct SMJStream {
653
    /// Current state of the stream
654
    pub state: SMJState,
655
    /// Output schema
656
    pub schema: SchemaRef,
657
    /// Sort options of join columns used to sort streamed and buffered data stream
658
    pub sort_options: Vec<SortOptions>,
659
    /// null == null?
660
    pub null_equals_null: bool,
661
    /// Input schema of streamed
662
    pub streamed_schema: SchemaRef,
663
    /// Input schema of buffered
664
    pub buffered_schema: SchemaRef,
665
    /// Streamed data stream
666
    pub streamed: SendableRecordBatchStream,
667
    /// Buffered data stream
668
    pub buffered: SendableRecordBatchStream,
669
    /// Current processing record batch of streamed
670
    pub streamed_batch: StreamedBatch,
671
    /// Current buffered data
672
    pub buffered_data: BufferedData,
673
    /// (used in outer join) Is current streamed row joined at least once?
674
    pub streamed_joined: bool,
675
    /// (used in outer join) Is current buffered batches joined at least once?
676
    pub buffered_joined: bool,
677
    /// State of streamed
678
    pub streamed_state: StreamedState,
679
    /// State of buffered
680
    pub buffered_state: BufferedState,
681
    /// The comparison result of current streamed row and buffered batches
682
    pub current_ordering: Ordering,
683
    /// Join key columns of streamed
684
    pub on_streamed: Vec<PhysicalExprRef>,
685
    /// Join key columns of buffered
686
    pub on_buffered: Vec<PhysicalExprRef>,
687
    /// optional join filter
688
    pub filter: Option<JoinFilter>,
689
    /// Staging output array builders
690
    pub output_record_batches: Vec<RecordBatch>,
691
    /// Staging output size, including output batches and staging joined results.
692
    /// Increased when we put rows into buffer and decreased after we actually output batches.
693
    /// Used to trigger output when sufficient rows are ready
694
    pub output_size: usize,
695
    /// Target output batch size
696
    pub batch_size: usize,
697
    /// How the join is performed
698
    pub join_type: JoinType,
699
    /// Metrics
700
    pub join_metrics: SortMergeJoinMetrics,
701
    /// Memory reservation
702
    pub reservation: MemoryReservation,
703
    /// Runtime env
704
    pub runtime_env: Arc<RuntimeEnv>,
705
}
706
707
impl RecordBatchStream for SMJStream {
708
0
    fn schema(&self) -> SchemaRef {
709
0
        Arc::clone(&self.schema)
710
0
    }
711
}
712
713
impl Stream for SMJStream {
714
    type Item = Result<RecordBatch>;
715
716
429
    fn poll_next(
717
429
        mut self: Pin<&mut Self>,
718
429
        cx: &mut Context<'_>,
719
429
    ) -> Poll<Option<Self::Item>> {
720
429
        let join_time = self.join_metrics.join_time.clone();
721
429
        let _timer = join_time.timer();
722
723
        loop {
724
2.28k
            match &self.state {
725
                SMJState::Init => {
726
648
                    let streamed_exhausted =
727
648
                        self.streamed_state == StreamedState::Exhausted;
728
648
                    let buffered_exhausted =
729
648
                        self.buffered_state == BufferedState::Exhausted;
730
648
                    self.state = if streamed_exhausted && 
buffered_exhausted81
{
731
0
                        SMJState::Exhausted
732
                    } else {
733
648
                        match self.current_ordering {
734
                            Ordering::Less | Ordering::Equal => {
735
437
                                if !streamed_exhausted {
736
437
                                    self.streamed_joined = false;
737
437
                                    self.streamed_state = StreamedState::Init;
738
437
                                }
0
739
                            }
740
                            Ordering::Greater => {
741
211
                                if !buffered_exhausted {
742
211
                                    self.buffered_joined = false;
743
211
                                    self.buffered_state = BufferedState::Init;
744
211
                                }
0
745
                            }
746
                        }
747
648
                        SMJState::Polling
748
                    };
749
                }
750
                SMJState::Polling => {
751
648
                    if ![StreamedState::Exhausted, StreamedState::Ready]
752
648
                        .contains(&self.streamed_state)
753
                    {
754
437
                        match self.poll_streamed_row(cx)
?0
{
755
437
                            Poll::Ready(_) => {}
756
0
                            Poll::Pending => return Poll::Pending,
757
                        }
758
211
                    }
759
760
648
                    if ![BufferedState::Exhausted, BufferedState::Ready]
761
648
                        .contains(&self.buffered_state)
762
                    {
763
290
                        match self.poll_buffered_batches(cx)
?12
{
764
278
                            Poll::Ready(_) => {}
765
0
                            Poll::Pending => return Poll::Pending,
766
                        }
767
358
                    }
768
636
                    let streamed_exhausted =
769
636
                        self.streamed_state == StreamedState::Exhausted;
770
636
                    let buffered_exhausted =
771
636
                        self.buffered_state == BufferedState::Exhausted;
772
636
                    if streamed_exhausted && 
buffered_exhausted148
{
773
67
                        self.state = SMJState::Exhausted;
774
67
                        continue;
775
569
                    }
776
569
                    self.current_ordering = self.compare_streamed_buffered()
?0
;
777
569
                    self.state = SMJState::JoinOutput;
778
                }
779
                SMJState::JoinOutput => {
780
878
                    self.join_partial()
?0
;
781
782
878
                    if self.output_size < self.batch_size {
783
569
                        if self.buffered_data.scanning_finished() {
784
569
                            self.buffered_data.scanning_reset();
785
569
                            self.state = SMJState::Init;
786
569
                        }
0
787
                    } else {
788
309
                        self.freeze_all()
?0
;
789
309
                        if !self.output_record_batches.is_empty() {
790
309
                            let record_batch = self.output_record_batch_and_reset()
?0
;
791
309
                            return Poll::Ready(Some(Ok(record_batch)));
792
0
                        }
793
0
                        return Poll::Pending;
794
                    }
795
                }
796
                SMJState::Exhausted => {
797
108
                    self.freeze_all()
?0
;
798
108
                    if !self.output_record_batches.is_empty() {
799
41
                        let record_batch = self.output_record_batch_and_reset()
?0
;
800
41
                        return Poll::Ready(Some(Ok(record_batch)));
801
67
                    }
802
67
                    return Poll::Ready(None);
803
                }
804
            }
805
        }
806
429
    }
807
}
808
809
impl SMJStream {
810
    #[allow(clippy::too_many_arguments)]
811
79
    pub fn try_new(
812
79
        schema: SchemaRef,
813
79
        sort_options: Vec<SortOptions>,
814
79
        null_equals_null: bool,
815
79
        streamed: SendableRecordBatchStream,
816
79
        buffered: SendableRecordBatchStream,
817
79
        on_streamed: Vec<Arc<dyn PhysicalExpr>>,
818
79
        on_buffered: Vec<Arc<dyn PhysicalExpr>>,
819
79
        filter: Option<JoinFilter>,
820
79
        join_type: JoinType,
821
79
        batch_size: usize,
822
79
        join_metrics: SortMergeJoinMetrics,
823
79
        reservation: MemoryReservation,
824
79
        runtime_env: Arc<RuntimeEnv>,
825
79
    ) -> Result<Self> {
826
79
        let streamed_schema = streamed.schema();
827
79
        let buffered_schema = buffered.schema();
828
79
        Ok(Self {
829
79
            state: SMJState::Init,
830
79
            sort_options,
831
79
            null_equals_null,
832
79
            schema,
833
79
            streamed_schema: Arc::clone(&streamed_schema),
834
79
            buffered_schema,
835
79
            streamed,
836
79
            buffered,
837
79
            streamed_batch: StreamedBatch::new_empty(streamed_schema),
838
79
            buffered_data: BufferedData::default(),
839
79
            streamed_joined: false,
840
79
            buffered_joined: false,
841
79
            streamed_state: StreamedState::Init,
842
79
            buffered_state: BufferedState::Init,
843
79
            current_ordering: Ordering::Equal,
844
79
            on_streamed,
845
79
            on_buffered,
846
79
            filter,
847
79
            output_record_batches: vec![],
848
79
            output_size: 0,
849
79
            batch_size,
850
79
            join_type,
851
79
            join_metrics,
852
79
            reservation,
853
79
            runtime_env,
854
79
        })
855
79
    }
856
857
    /// Poll next streamed row.
    ///
    /// Moves the streamed side forward by one row:
    /// - `Poll::Ready(Some(Ok(())))` once `self.streamed_batch.idx` refers to the
    ///   row that should be joined next (`streamed_state` is then `Ready`);
    /// - `Poll::Ready(None)` once the streamed input stream has ended
    ///   (`streamed_state` is then `Exhausted`);
    /// - `Poll::Pending` when the input stream is not ready yet.
    ///
    /// Errors from the input stream or from `freeze_streamed` are surfaced as
    /// `Poll::Ready(Some(Err(_)))` through `?`.
    fn poll_streamed_row(&mut self, cx: &mut Context) -> Poll<Option<Result<()>>> {
        loop {
            match &self.streamed_state {
                StreamedState::Init => {
                    // More rows remain in the current batch: just advance the cursor.
                    if self.streamed_batch.idx + 1 < self.streamed_batch.batch.num_rows()
                    {
                        self.streamed_batch.idx += 1;
                        self.streamed_state = StreamedState::Ready;
                        return Poll::Ready(Some(Ok(())));
                    } else {
                        // Current batch is consumed: fetch the next one from the input.
                        self.streamed_state = StreamedState::Polling;
                    }
                }
                StreamedState::Polling => match self.streamed.poll_next_unpin(cx)? {
                    Poll::Pending => {
                        return Poll::Pending;
                    }
                    Poll::Ready(None) => {
                        self.streamed_state = StreamedState::Exhausted;
                    }
                    Poll::Ready(Some(batch)) => {
                        // Empty batches are skipped: the state stays `Polling` and the
                        // loop polls the input again.
                        if batch.num_rows() > 0 {
                            // Stage output still referring to the previous streamed batch
                            // before that batch is replaced below.
                            self.freeze_streamed()?;
                            self.join_metrics.input_batches.add(1);
                            self.join_metrics.input_rows.add(batch.num_rows());
                            // NOTE(review): presumably `StreamedBatch::new` positions `idx`
                            // on the first row — confirm in `StreamedBatch::new`.
                            self.streamed_batch =
                                StreamedBatch::new(batch, &self.on_streamed);
                            self.streamed_state = StreamedState::Ready;
                        }
                    }
                },
                StreamedState::Ready => {
                    return Poll::Ready(Some(Ok(())));
                }
                StreamedState::Exhausted => {
                    return Poll::Ready(None);
                }
            }
        }
    }
898
899
118
    /// Releases the memory reserved for `buffered_batch` and then drops the batch.
    ///
    /// Only batches that are still held in memory (no spill file and a present body)
    /// are accounted in the reservation. A batch that was spilled to disk never grew
    /// the reservation, so nothing is shrunk for it.
    fn free_reservation(&mut self, buffered_batch: BufferedBatch) -> Result<()> {
        let held_in_memory =
            buffered_batch.batch.is_some() && buffered_batch.spill_file.is_none();
        if !held_in_memory {
            return Ok(());
        }

        self.reservation.try_shrink(buffered_batch.size_estimation)?;
        Ok(())
    }
908
909
130
    /// Accounts for a newly polled buffered batch and appends it to `buffered_data`.
    ///
    /// First tries to grow the memory reservation by the batch's `size_estimation`.
    /// On success, the peak-memory metric is updated and the batch stays in memory.
    ///
    /// If the reservation cannot grow and the disk manager allows temporary files,
    /// the batch body is written to a temp file instead. In that case `spill_file` is
    /// set, `batch` is cleared, spill metrics are updated, and the reservation is left
    /// unchanged.
    ///
    /// # Errors
    /// - an execution error ("... Disk spilling disabled.") when the reservation
    ///   cannot grow and temporary files are disabled; the batch is not queued;
    /// - errors from creating the temp file or writing the spill;
    /// - an internal error if the batch has no in-memory body to spill.
    fn allocate_reservation(&mut self, mut buffered_batch: BufferedBatch) -> Result<()> {
        match self.reservation.try_grow(buffered_batch.size_estimation) {
            Ok(_) => {
                self.join_metrics
                    .peak_mem_used
                    .set_max(self.reservation.size());
                Ok(())
            }
            Err(_) if self.runtime_env.disk_manager.tmp_files_enabled() => {
                // spill buffered batch to disk
                let spill_file = self
                    .runtime_env
                    .disk_manager
                    .create_tmp_file("sort_merge_join_buffered_spill")?;

                // Moves the body out of `buffered_batch`; the field is re-initialised
                // to `None` below once the spill succeeded.
                if let Some(batch) = buffered_batch.batch {
                    spill_record_batches(
                        vec![batch],
                        spill_file.path().into(),
                        Arc::clone(&self.buffered_schema),
                    )?;
                    buffered_batch.spill_file = Some(spill_file);
                    buffered_batch.batch = None;

                    // update metrics to register spill
                    self.join_metrics.spill_count.add(1);
                    self.join_metrics
                        .spilled_bytes
                        .add(buffered_batch.size_estimation);
                    self.join_metrics.spilled_rows.add(buffered_batch.num_rows);
                    Ok(())
                } else {
                    internal_err!("Buffered batch has empty body")
                }
            }
            Err(e) => exec_err!("{}. Disk spilling disabled.", e.message()),
        }?;

        // Queue the batch (in memory or spilled) behind the existing buffered batches.
        self.buffered_data.batches.push_back(buffered_batch);
        Ok(())
    }
950
951
    /// Poll next buffered batches.
    ///
    /// Collects in `buffered_data` the next run of buffered rows whose join key equals
    /// the key at the head batch's `range.start`. The run may span several batches,
    /// from the head batch's `range` through the tail batch's `range`.
    ///
    /// - `Poll::Ready(Some(Ok(())))` once the run is complete: either a row with a
    ///   different key was found, or the buffered input ended (`buffered_state` is
    ///   then `Ready`);
    /// - `Poll::Ready(None)` when the buffered input is exhausted and no buffered
    ///   batches remain;
    /// - `Poll::Pending` when the input stream is not ready yet.
    ///
    /// Errors from the input stream, key comparison, freezing output or memory
    /// accounting are surfaced as `Poll::Ready(Some(Err(_)))` through `?`.
    fn poll_buffered_batches(&mut self, cx: &mut Context) -> Poll<Option<Result<()>>> {
        loop {
            match &self.buffered_state {
                BufferedState::Init => {
                    // pop previous buffered batches
                    while !self.buffered_data.batches.is_empty() {
                        let head_batch = self.buffered_data.head_batch();
                        // If the head batch is fully processed, dequeue it and produce output of it.
                        if head_batch.range.end == head_batch.num_rows {
                            self.freeze_dequeuing_buffered()?;
                            if let Some(buffered_batch) =
                                self.buffered_data.batches.pop_front()
                            {
                                self.free_reservation(buffered_batch)?;
                            }
                        } else {
                            // If the head batch is not fully processed, break the loop.
                            // Streamed batch will be joined with the head batch in the next step.
                            break;
                        }
                    }
                    if self.buffered_data.batches.is_empty() {
                        // Nothing buffered: the next run starts with a freshly polled batch.
                        self.buffered_state = BufferedState::PollingFirst;
                    } else {
                        // Start the next run right after the previous one, in the tail
                        // batch, initially containing a single row.
                        let tail_batch = self.buffered_data.tail_batch_mut();
                        tail_batch.range.start = tail_batch.range.end;
                        tail_batch.range.end += 1;
                        self.buffered_state = BufferedState::PollingRest;
                    }
                }
                BufferedState::PollingFirst => match self.buffered.poll_next_unpin(cx)? {
                    Poll::Pending => {
                        return Poll::Pending;
                    }
                    Poll::Ready(None) => {
                        self.buffered_state = BufferedState::Exhausted;
                        return Poll::Ready(None);
                    }
                    Poll::Ready(Some(batch)) => {
                        // NOTE(review): input metrics here also count empty batches,
                        // unlike `poll_streamed_row` which counts only non-empty ones.
                        self.join_metrics.input_batches.add(1);
                        self.join_metrics.input_rows.add(batch.num_rows());

                        // Empty batches are skipped and the input is polled again.
                        if batch.num_rows() > 0 {
                            // The run starts at the first row of this batch.
                            let buffered_batch =
                                BufferedBatch::new(batch, 0..1, &self.on_buffered);

                            self.allocate_reservation(buffered_batch)?;
                            self.buffered_state = BufferedState::PollingRest;
                        }
                    }
                },
                BufferedState::PollingRest => {
                    if self.buffered_data.tail_batch().range.end
                        < self.buffered_data.tail_batch().num_rows
                    {
                        // Extend the run through the tail batch while rows keep the
                        // same join key as the first row of the run.
                        while self.buffered_data.tail_batch().range.end
                            < self.buffered_data.tail_batch().num_rows
                        {
                            if is_join_arrays_equal(
                                &self.buffered_data.head_batch().join_arrays,
                                self.buffered_data.head_batch().range.start,
                                &self.buffered_data.tail_batch().join_arrays,
                                self.buffered_data.tail_batch().range.end,
                            )? {
                                self.buffered_data.tail_batch_mut().range.end += 1;
                            } else {
                                // Key changed: the run is complete.
                                self.buffered_state = BufferedState::Ready;
                                return Poll::Ready(Some(Ok(())));
                            }
                        }
                    } else {
                        // Tail batch fully consumed without a key change: the run may
                        // continue into the next input batch.
                        match self.buffered.poll_next_unpin(cx)? {
                            Poll::Pending => {
                                return Poll::Pending;
                            }
                            Poll::Ready(None) => {
                                self.buffered_state = BufferedState::Ready;
                            }
                            Poll::Ready(Some(batch)) => {
                                // Polling batches coming concurrently as multiple partitions
                                self.join_metrics.input_batches.add(1);
                                self.join_metrics.input_rows.add(batch.num_rows());
                                if batch.num_rows() > 0 {
                                    // Empty range: rows are added by the comparison loop
                                    // above on the next iteration.
                                    let buffered_batch = BufferedBatch::new(
                                        batch,
                                        0..0,
                                        &self.on_buffered,
                                    );
                                    self.allocate_reservation(buffered_batch)?;
                                }
                            }
                        }
                    }
                }
                BufferedState::Ready => {
                    return Poll::Ready(Some(Ok(())));
                }
                BufferedState::Exhausted => {
                    return Poll::Ready(None);
                }
            }
        }
    }
1055
1056
    /// Get comparison result of streamed row and buffered batches
1057
569
    fn compare_streamed_buffered(&self) -> Result<Ordering> {
1058
569
        if self.streamed_state == StreamedState::Exhausted {
1059
81
            return Ok(Ordering::Greater);
1060
488
        }
1061
488
        if !self.buffered_data.has_buffered_rows() {
1062
12
            return Ok(Ordering::Less);
1063
476
        }
1064
476
1065
476
        return compare_join_arrays(
1066
476
            &self.streamed_batch.join_arrays,
1067
476
            self.streamed_batch.idx,
1068
476
            &self.buffered_data.head_batch().join_arrays,
1069
476
            self.buffered_data.head_batch().range.start,
1070
476
            &self.sort_options,
1071
476
            self.null_equals_null,
1072
476
        );
1073
569
    }
1074
1075
    /// Produce join and fill output buffer until reaching target batch size
    /// or the join is finished.
    ///
    /// Based on `self.current_ordering` and `self.join_type`, this decides whether
    /// the current streamed row and/or the current buffered run take part in the
    /// output. Output is staged rather than materialised:
    /// - streamed/buffered index pairs go through `streamed_batch.append_output_pair`;
    /// - buffered rows joined with NULLs (FULL join, `Greater`) are pushed to the
    ///   scanning batch's `null_joined`.
    ///
    /// The buffered scan stops as soon as `output_size` reaches `batch_size`, so a
    /// single call may emit only part of the buffered run. `streamed_joined` and
    /// `buffered_joined` are updated once the scan finishes.
    fn join_partial(&mut self) -> Result<()> {
        // Whether to join streamed rows
        let mut join_streamed = false;
        // Whether to join buffered rows
        let mut join_buffered = false;

        // determine whether we need to join streamed/buffered rows
        match self.current_ordering {
            Ordering::Less => {
                // The streamed row has no equal buffered key; these join types emit it
                // once (guarded by `streamed_joined`) without a buffered match.
                // NOTE(review): `RightSemi` is listed here but not in the `Equal` arm,
                // i.e. it emits streamed rows that have no match; confirm this is intended.
                if matches!(
                    self.join_type,
                    JoinType::Left
                        | JoinType::Right
                        | JoinType::RightSemi
                        | JoinType::Full
                        | JoinType::LeftAnti
                ) {
                    join_streamed = !self.streamed_joined;
                }
            }
            Ordering::Equal => {
                if matches!(self.join_type, JoinType::LeftSemi) {
                    // if the join filter is specified then its needed to output the streamed index
                    // only if it has not been emitted before
                    // the `join_filter_matched_idxs` keeps track on if streamed index has a successful
                    // filter match and prevents the same index to go into output more than once
                    if self.filter.is_some() {
                        join_streamed = !self
                            .streamed_batch
                            .join_filter_matched_idxs
                            .contains(&(self.streamed_batch.idx as u64))
                            && !self.streamed_joined;
                        // if the join filter specified there can be references to buffered columns
                        // so buffered columns are needed to access them
                        join_buffered = join_streamed;
                    } else {
                        join_streamed = !self.streamed_joined;
                    }
                }
                // Inner and outer joins pair the streamed row with every row of the run.
                if matches!(
                    self.join_type,
                    JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full
                ) {
                    join_streamed = true;
                    join_buffered = true;
                };

                // Without a filter, a LeftAnti streamed row with an equal key is simply
                // not emitted. With a filter, the matched pairs must be evaluated, so they
                // are joined unless this streamed index already passed the filter.
                if matches!(self.join_type, JoinType::LeftAnti) && self.filter.is_some() {
                    join_streamed = !self
                        .streamed_batch
                        .join_filter_matched_idxs
                        .contains(&(self.streamed_batch.idx as u64))
                        && !self.streamed_joined;
                    join_buffered = join_streamed;
                }
            }
            Ordering::Greater => {
                // The buffered run has no equal streamed key; only FULL join emits it
                // (with NULLs on the streamed side), once per run.
                if matches!(self.join_type, JoinType::Full) {
                    join_buffered = !self.buffered_joined;
                };
            }
        }
        if !join_streamed && !join_buffered {
            // no joined data
            self.buffered_data.scanning_finish();
            return Ok(());
        }

        if join_buffered {
            // joining streamed/nulls and buffered
            while !self.buffered_data.scanning_finished()
                && self.output_size < self.batch_size
            {
                let scanning_idx = self.buffered_data.scanning_idx();
                if join_streamed {
                    // Join streamed row and buffered row
                    self.streamed_batch.append_output_pair(
                        Some(self.buffered_data.scanning_batch_idx),
                        Some(scanning_idx),
                    );
                } else {
                    // Join nulls and buffered row for FULL join
                    self.buffered_data
                        .scanning_batch_mut()
                        .null_joined
                        .push(scanning_idx);
                }
                self.output_size += 1;
                self.buffered_data.scanning_advance();

                // The whole run has been emitted: remember it so later calls for the
                // same rows do not emit them again.
                if self.buffered_data.scanning_finished() {
                    self.streamed_joined = join_streamed;
                    self.buffered_joined = true;
                }
            }
        } else {
            // joining streamed and nulls
            let scanning_batch_idx = if self.buffered_data.scanning_finished() {
                None
            } else {
                Some(self.buffered_data.scanning_batch_idx)
            };

            self.streamed_batch
                .append_output_pair(scanning_batch_idx, None);
            self.output_size += 1;
            self.buffered_data.scanning_finish();
            self.streamed_joined = true;
        }
        Ok(())
    }
1188
1189
417
    /// Turns every staged output into record batches: first the index pairs staged
    /// for the current streamed batch, then (FULL join only, see `freeze_buffered`)
    /// the NULL-joined rows of every buffered batch currently queued.
    fn freeze_all(&mut self) -> Result<()> {
        self.freeze_streamed()?;
        // Count the queue only after freezing the streamed side, as the original order did.
        let queued = self.buffered_data.batches.len();
        self.freeze_buffered(queued, false)
    }
1194
1195
    // Produces and stages record batches to ensure dequeued buffered batch
1196
    // no longer needed:
1197
    //   1. freezes all indices joined to streamed side
1198
    //   2. freezes NULLs joined to dequeued buffered batch to "release" it
1199
118
    fn freeze_dequeuing_buffered(&mut self) -> Result<()> {
1200
118
        self.freeze_streamed()
?0
;
1201
        // Only freeze and produce the first batch in buffered_data as the batch is fully processed
1202
118
        self.freeze_buffered(1, true)
?0
;
1203
118
        Ok(())
1204
118
    }
1205
1206
    // Produces and stages record batch from buffered indices with corresponding
1207
    // NULLs on streamed side.
1208
    //
1209
    // Applicable only in case of Full join.
1210
    //
1211
    // If `output_not_matched_filter` is true, this will also produce record batches
1212
    // for buffered rows which are joined with streamed side but don't match join filter.
1213
535
    fn freeze_buffered(
1214
535
        &mut self,
1215
535
        batch_count: usize,
1216
535
        output_not_matched_filter: bool,
1217
535
    ) -> Result<()> {
1218
535
        if !
matches!426
(self.join_type, JoinType::Full) {
1219
426
            return Ok(());
1220
109
        }
1221
213
        for buffered_batch in 
self.buffered_data.batches.range_mut(..batch_count)109
{
1222
213
            let buffered_indices = UInt64Array::from_iter_values(
1223
213
                buffered_batch.null_joined.iter().map(|&index| 
index as u647
),
1224
213
            );
1225
213
            if let Some(
record_batch7
) = produce_buffered_null_batch(
1226
213
                &self.schema,
1227
213
                &self.streamed_schema,
1228
213
                &buffered_indices,
1229
213
                buffered_batch,
1230
213
            )
?0
{
1231
7
                self.output_record_batches.push(record_batch);
1232
206
            }
1233
213
            buffered_batch.null_joined.clear();
1234
213
1235
213
            // For buffered row which is joined with streamed side rows but all joined rows
1236
213
            // don't satisfy the join filter
1237
213
            if output_not_matched_filter {
1238
19
                let not_matched_buffered_indices = buffered_batch
1239
19
                    .join_filter_failed_map
1240
19
                    .iter()
1241
19
                    .filter_map(|(idx, failed)| 
if *failed0
{
Some(*idx)0
} else {
None0
}0
)
1242
19
                    .collect::<Vec<_>>();
1243
19
1244
19
                let buffered_indices = UInt64Array::from_iter_values(
1245
19
                    not_matched_buffered_indices.iter().copied(),
1246
19
                );
1247
1248
19
                if let Some(
record_batch0
) = produce_buffered_null_batch(
1249
19
                    &self.schema,
1250
19
                    &self.streamed_schema,
1251
19
                    &buffered_indices,
1252
19
                    buffered_batch,
1253
19
                )
?0
{
1254
0
                    self.output_record_batches.push(record_batch);
1255
19
                }
1256
19
                buffered_batch.join_filter_failed_map.clear();
1257
194
            }
1258
        }
1259
109
        Ok(())
1260
535
    }
1261
1262
    // Produces and stages record batch for all output indices found
1263
    // for current streamed batch and clears staged output indices.
1264
665
    fn freeze_streamed(&mut self) -> Result<()> {
1265
665
        for 
chunk504
in self.streamed_batch.output_indices.iter_mut() {
1266
            // The row indices of joined streamed batch
1267
504
            let streamed_indices = chunk.streamed_indices.finish();
1268
504
1269
504
            if streamed_indices.is_empty() {
1270
0
                continue;
1271
504
            }
1272
1273
504
            let mut streamed_columns = self
1274
504
                .streamed_batch
1275
504
                .batch
1276
504
                .columns()
1277
504
                .iter()
1278
1.51k
                .map(|column| take(column, &streamed_indices, None))
1279
504
                .collect::<Result<Vec<_>, ArrowError>>()
?0
;
1280
1281
            // The row indices of joined buffered batch
1282
504
            let buffered_indices: UInt64Array = chunk.buffered_indices.finish();
1283
504
            let mut buffered_columns =
1284
504
                if 
matches!468
(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
1285
36
                    vec![]
1286
468
                } else if let Some(
buffered_idx459
) = chunk.buffered_batch_idx {
1287
459
                    get_buffered_columns(
1288
459
                        &self.buffered_data,
1289
459
                        buffered_idx,
1290
459
                        &buffered_indices,
1291
459
                    )
?0
1292
                } else {
1293
                    // If buffered batch none, meaning it is null joined batch.
1294
                    // We need to create null arrays for buffered columns to join with streamed rows.
1295
9
                    self.buffered_schema
1296
9
                        .fields()
1297
9
                        .iter()
1298
27
                        .map(|f| new_null_array(f.data_type(), buffered_indices.len()))
1299
9
                        .collect::<Vec<_>>()
1300
                };
1301
1302
504
            let streamed_columns_length = streamed_columns.len();
1303
504
            let buffered_columns_length = buffered_columns.len();
1304
1305
            // Prepare the columns we apply join filter on later.
1306
            // Only for joined rows between streamed and buffered.
1307
504
            let filter_columns = if chunk.buffered_batch_idx.is_some() {
1308
494
                if 
matches!386
(self.join_type, JoinType::Right) {
1309
108
                    get_filter_column(&self.filter, &buffered_columns, &streamed_columns)
1310
351
                } else if matches!(
1311
386
                    self.join_type,
1312
                    JoinType::LeftSemi | JoinType::LeftAnti
1313
                ) {
1314
                    // unwrap is safe here as we check is_some on top of if statement
1315
35
                    let buffered_columns = get_buffered_columns(
1316
35
                        &self.buffered_data,
1317
35
                        chunk.buffered_batch_idx.unwrap(),
1318
35
                        &buffered_indices,
1319
35
                    )
?0
;
1320
1321
35
                    get_filter_column(&self.filter, &streamed_columns, &buffered_columns)
1322
                } else {
1323
351
                    get_filter_column(&self.filter, &streamed_columns, &buffered_columns)
1324
                }
1325
            } else {
1326
                // This chunk is totally for null joined rows (outer join), we don't need to apply join filter.
1327
                // Any join filter applied only on either streamed or buffered side will be pushed already.
1328
10
                vec![]
1329
            };
1330
1331
504
            let columns = if 
matches!391
(self.join_type, JoinType::Right) {
1332
113
                buffered_columns.extend(streamed_columns.clone());
1333
113
                buffered_columns
1334
            } else {
1335
391
                streamed_columns.extend(buffered_columns);
1336
391
                streamed_columns
1337
            };
1338
1339
504
            let output_batch =
1340
504
                RecordBatch::try_new(Arc::clone(&self.schema), columns.clone())
?0
;
1341
1342
            // Apply join filter if any
1343
504
            if !filter_columns.is_empty() {
1344
0
                if let Some(f) = &self.filter {
1345
                    // Construct batch with only filter columns
1346
0
                    let filter_batch = RecordBatch::try_new(
1347
0
                        Arc::new(f.schema().clone()),
1348
0
                        filter_columns,
1349
0
                    )?;
1350
1351
0
                    let filter_result = f
1352
0
                        .expression()
1353
0
                        .evaluate(&filter_batch)?
1354
0
                        .into_array(filter_batch.num_rows())?;
1355
1356
                    // The boolean selection mask of the join filter result
1357
0
                    let pre_mask =
1358
0
                        datafusion_common::cast::as_boolean_array(&filter_result)?;
1359
1360
                    // If there are nulls in join filter result, exclude them from selecting
1361
                    // the rows to output.
1362
0
                    let mask = if pre_mask.null_count() > 0 {
1363
                        compute::prep_null_mask_filter(
1364
0
                            datafusion_common::cast::as_boolean_array(&filter_result)?,
1365
                        )
1366
                    } else {
1367
0
                        pre_mask.clone()
1368
                    };
1369
1370
                    // For certain join types, we need to adjust the initial mask to handle the join filter.
1371
0
                    let maybe_filtered_join_mask: Option<(BooleanArray, Vec<u64>)> =
1372
0
                        get_filtered_join_mask(
1373
0
                            self.join_type,
1374
0
                            &streamed_indices,
1375
0
                            &mask,
1376
0
                            &self.streamed_batch.join_filter_matched_idxs,
1377
0
                            &self.buffered_data.scanning_offset,
1378
0
                        );
1379
1380
0
                    let mask =
1381
0
                        if let Some(ref filtered_join_mask) = maybe_filtered_join_mask {
1382
0
                            self.streamed_batch
1383
0
                                .join_filter_matched_idxs
1384
0
                                .extend(&filtered_join_mask.1);
1385
0
                            &filtered_join_mask.0
1386
                        } else {
1387
0
                            &mask
1388
                        };
1389
1390
                    // Push the filtered batch which contains rows passing join filter to the output
1391
0
                    let filtered_batch =
1392
0
                        compute::filter_record_batch(&output_batch, mask)?;
1393
0
                    self.output_record_batches.push(filtered_batch);
1394
1395
                    // For outer joins, we need to push the null joined rows to the output if
1396
                    // all joined rows are failed on the join filter.
1397
                    // I.e., if all rows joined from a streamed row are failed with the join filter,
1398
                    // we need to join it with nulls as buffered side.
1399
0
                    if matches!(
1400
0
                        self.join_type,
1401
                        JoinType::Left | JoinType::Right | JoinType::Full
1402
                    ) {
1403
                        // We need to get the mask for row indices that the joined rows are failed
1404
                        // on the join filter. I.e., for a row in streamed side, if all joined rows
1405
                        // between it and all buffered rows are failed on the join filter, we need to
1406
                        // output it with null columns from buffered side. For the mask here, it
1407
                        // behaves like LeftAnti join.
1408
0
                        let null_mask: BooleanArray = get_filtered_join_mask(
1409
0
                            // Set a mask slot as true only if all joined rows of same streamed index
1410
0
                            // are failed on the join filter.
1411
0
                            // The masking behavior is like LeftAnti join.
1412
0
                            JoinType::LeftAnti,
1413
0
                            &streamed_indices,
1414
0
                            mask,
1415
0
                            &self.streamed_batch.join_filter_matched_idxs,
1416
0
                            &self.buffered_data.scanning_offset,
1417
0
                        )
1418
0
                        .unwrap()
1419
0
                        .0;
1420
1421
0
                        let null_joined_batch =
1422
0
                            compute::filter_record_batch(&output_batch, &null_mask)?;
1423
1424
0
                        let mut buffered_columns = self
1425
0
                            .buffered_schema
1426
0
                            .fields()
1427
0
                            .iter()
1428
0
                            .map(|f| {
1429
0
                                new_null_array(
1430
0
                                    f.data_type(),
1431
0
                                    null_joined_batch.num_rows(),
1432
0
                                )
1433
0
                            })
1434
0
                            .collect::<Vec<_>>();
1435
1436
0
                        let columns = if matches!(self.join_type, JoinType::Right) {
1437
0
                            let streamed_columns = null_joined_batch
1438
0
                                .columns()
1439
0
                                .iter()
1440
0
                                .skip(buffered_columns_length)
1441
0
                                .cloned()
1442
0
                                .collect::<Vec<_>>();
1443
0
1444
0
                            buffered_columns.extend(streamed_columns);
1445
0
                            buffered_columns
1446
                        } else {
1447
                            // Left join or full outer join
1448
0
                            let mut streamed_columns = null_joined_batch
1449
0
                                .columns()
1450
0
                                .iter()
1451
0
                                .take(streamed_columns_length)
1452
0
                                .cloned()
1453
0
                                .collect::<Vec<_>>();
1454
0
1455
0
                            streamed_columns.extend(buffered_columns);
1456
0
                            streamed_columns
1457
                        };
1458
1459
                        // Push the streamed/buffered batch joined nulls to the output
1460
0
                        let null_joined_streamed_batch = RecordBatch::try_new(
1461
0
                            Arc::clone(&self.schema),
1462
0
                            columns.clone(),
1463
0
                        )?;
1464
0
                        self.output_record_batches.push(null_joined_streamed_batch);
1465
1466
                        // For full join, we also need to output the null joined rows from the buffered side.
1467
                        // Usually this is done by `freeze_buffered`. However, if a buffered row is joined with
1468
                        // streamed side, it won't be outputted by `freeze_buffered`.
1469
                        // We need to check if a buffered row is joined with streamed side and output.
1470
                        // If it is joined with streamed side, but doesn't match the join filter,
1471
                        // we need to output it with nulls as streamed side.
1472
0
                        if matches!(self.join_type, JoinType::Full) {
1473
0
                            let buffered_batch = &mut self.buffered_data.batches
1474
0
                                [chunk.buffered_batch_idx.unwrap()];
1475
1476
0
                            for i in 0..pre_mask.len() {
1477
                                // If the buffered row is not joined with streamed side,
1478
                                // skip it.
1479
0
                                if buffered_indices.is_null(i) {
1480
0
                                    continue;
1481
0
                                }
1482
0
1483
0
                                let buffered_index = buffered_indices.value(i);
1484
0
1485
0
                                buffered_batch.join_filter_failed_map.insert(
1486
0
                                    buffered_index,
1487
0
                                    *buffered_batch
1488
0
                                        .join_filter_failed_map
1489
0
                                        .get(&buffered_index)
1490
0
                                        .unwrap_or(&true)
1491
0
                                        && !pre_mask.value(i),
1492
                                );
1493
                            }
1494
0
                        }
1495
0
                    }
1496
0
                } else {
1497
0
                    self.output_record_batches.push(output_batch);
1498
0
                }
1499
504
            } else {
1500
504
                self.output_record_batches.push(output_batch);
1501
504
            }
1502
        }
1503
1504
665
        self.streamed_batch.output_indices.clear();
1505
665
1506
665
        Ok(())
1507
665
    }
1508
1509
350
    fn output_record_batch_and_reset(&mut self) -> Result<RecordBatch> {
1510
350
        let record_batch = concat_batches(&self.schema, &self.output_record_batches)
?0
;
1511
350
        self.join_metrics.output_batches.add(1);
1512
350
        self.join_metrics.output_rows.add(record_batch.num_rows());
1513
350
        // If join filter exists, `self.output_size` is not accurate as we don't know the exact
1514
350
        // number of rows in the output record batch. If streamed row joined with buffered rows,
1515
350
        // once join filter is applied, the number of output rows may be more than 1.
1516
350
        // If `record_batch` is empty, we should reset `self.output_size` to 0. It could be happened
1517
350
        // when the join filter is applied and all rows are filtered out.
1518
350
        if record_batch.num_rows() == 0 || record_batch.num_rows() > self.output_size {
1519
0
            self.output_size = 0;
1520
350
        } else {
1521
350
            self.output_size -= record_batch.num_rows();
1522
350
        }
1523
350
        self.output_record_batches.clear();
1524
350
        Ok(record_batch)
1525
350
    }
1526
}
1527
1528
/// Gets the arrays which join filters are applied on.
1529
494
fn get_filter_column(
1530
494
    join_filter: &Option<JoinFilter>,
1531
494
    streamed_columns: &[ArrayRef],
1532
494
    buffered_columns: &[ArrayRef],
1533
494
) -> Vec<ArrayRef> {
1534
494
    let mut filter_columns = vec![];
1535
1536
494
    if let Some(
f0
) = join_filter {
1537
0
        let left_columns = f
1538
0
            .column_indices()
1539
0
            .iter()
1540
0
            .filter(|col_index| col_index.side == JoinSide::Left)
1541
0
            .map(|i| Arc::clone(&streamed_columns[i.index]))
1542
0
            .collect::<Vec<_>>();
1543
0
1544
0
        let right_columns = f
1545
0
            .column_indices()
1546
0
            .iter()
1547
0
            .filter(|col_index| col_index.side == JoinSide::Right)
1548
0
            .map(|i| Arc::clone(&buffered_columns[i.index]))
1549
0
            .collect::<Vec<_>>();
1550
0
1551
0
        filter_columns.extend(left_columns);
1552
0
        filter_columns.extend(right_columns);
1553
494
    }
1554
1555
494
    filter_columns
1556
494
}
1557
1558
232
fn produce_buffered_null_batch(
1559
232
    schema: &SchemaRef,
1560
232
    streamed_schema: &SchemaRef,
1561
232
    buffered_indices: &PrimitiveArray<UInt64Type>,
1562
232
    buffered_batch: &BufferedBatch,
1563
232
) -> Result<Option<RecordBatch>> {
1564
232
    if buffered_indices.is_empty() {
1565
225
        return Ok(None);
1566
7
    }
1567
1568
    // Take buffered (right) columns
1569
7
    let buffered_columns =
1570
7
        get_buffered_columns_from_batch(buffered_batch, buffered_indices)
?0
;
1571
1572
    // Create null streamed (left) columns
1573
7
    let mut streamed_columns = streamed_schema
1574
7
        .fields()
1575
7
        .iter()
1576
21
        .map(|f| new_null_array(f.data_type(), buffered_indices.len()))
1577
7
        .collect::<Vec<_>>();
1578
7
1579
7
    streamed_columns.extend(buffered_columns);
1580
7
1581
7
    Ok(Some(RecordBatch::try_new(
1582
7
        Arc::clone(schema),
1583
7
        streamed_columns,
1584
7
    )
?0
))
1585
232
}
1586
1587
/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]`
1588
#[inline(always)]
1589
494
fn get_buffered_columns(
1590
494
    buffered_data: &BufferedData,
1591
494
    buffered_batch_idx: usize,
1592
494
    buffered_indices: &UInt64Array,
1593
494
) -> Result<Vec<ArrayRef>> {
1594
494
    get_buffered_columns_from_batch(
1595
494
        &buffered_data.batches[buffered_batch_idx],
1596
494
        buffered_indices,
1597
494
    )
1598
494
}
1599
1600
#[inline(always)]
1601
501
fn get_buffered_columns_from_batch(
1602
501
    buffered_batch: &BufferedBatch,
1603
501
    buffered_indices: &UInt64Array,
1604
501
) -> Result<Vec<ArrayRef>> {
1605
501
    match (&buffered_batch.spill_file, &buffered_batch.batch) {
1606
        // In memory batch
1607
347
        (None, Some(batch)) => Ok(batch
1608
347
            .columns()
1609
347
            .iter()
1610
1.04k
            .map(|column| take(column, &buffered_indices, None))
1611
347
            .collect::<Result<Vec<_>, ArrowError>>()
1612
347
            .map_err(Into::<DataFusionError>::into)
?0
),
1613
        // If the batch was spilled to disk, less likely
1614
154
        (Some(spill_file), None) => {
1615
154
            let mut buffered_cols: Vec<ArrayRef> =
1616
154
                Vec::with_capacity(buffered_indices.len());
1617
1618
154
            let file = BufReader::new(File::open(spill_file.path())
?0
);
1619
154
            let reader = FileReader::try_new(file, None)
?0
;
1620
1621
308
            for 
batch154
in reader {
1622
462
                
batch154
?0
.
columns().iter().for_each(154
|column| {
1623
462
                    buffered_cols.extend(take(column, &buffered_indices, None))
1624
462
                });
1625
154
            }
1626
1627
154
            Ok(buffered_cols)
1628
        }
1629
        // Invalid combination
1630
0
        (spill, batch) => internal_err!("Unexpected buffered batch spill status. Spill exists: {}. In-memory exists: {}", spill.is_some(), batch.is_some()),
1631
    }
1632
501
}
1633
1634
/// Calculate join filter bit mask considering join type specifics
1635
/// `streamed_indices` - array of streamed datasource JOINED row indices
1636
/// `mask` - array booleans representing computed join filter expression eval result:
1637
///      true = the row index matches the join filter
1638
///      false = the row index doesn't match the join filter
1639
/// `streamed_indices` have the same length as `mask`
1640
/// `matched_indices` array of streaming indices that already has a join filter match
1641
/// `scanning_buffered_offset` current buffered offset across batches
1642
///
1643
/// This return a tuple of:
1644
/// - corrected mask with respect to the join type
1645
/// - indices of rows in streamed batch that have a join filter match
1646
13
fn get_filtered_join_mask(
1647
13
    join_type: JoinType,
1648
13
    streamed_indices: &UInt64Array,
1649
13
    mask: &BooleanArray,
1650
13
    matched_indices: &HashSet<u64>,
1651
13
    scanning_buffered_offset: &usize,
1652
13
) -> Option<(BooleanArray, Vec<u64>)> {
1653
13
    let mut seen_as_true: bool = false;
1654
13
    let streamed_indices_length = streamed_indices.len();
1655
13
    let mut corrected_mask: BooleanBuilder =
1656
13
        BooleanBuilder::with_capacity(streamed_indices_length);
1657
13
1658
13
    let mut filter_matched_indices: Vec<u64> = vec![];
1659
13
1660
13
    #[allow(clippy::needless_range_loop)]
1661
13
    match join_type {
1662
        // for LeftSemi Join the filter mask should be calculated in its own way:
1663
        // if we find at least one matching row for specific streaming index
1664
        // we don't need to check any others for the same index
1665
        JoinType::LeftSemi => {
1666
            // have we seen a filter match for a streaming index before
1667
28
            for i in 0..
streamed_indices_length7
{
1668
                // LeftSemi respects only first true values for specific streaming index,
1669
                // others true values for the same index must be false
1670
28
                let streamed_idx = streamed_indices.value(i);
1671
28
                if mask.value(i)
1672
14
                    && !seen_as_true
1673
10
                    && !matched_indices.contains(&streamed_idx)
1674
9
                {
1675
9
                    seen_as_true = true;
1676
9
                    corrected_mask.append_value(true);
1677
9
                    filter_matched_indices.push(streamed_idx);
1678
19
                } else {
1679
19
                    corrected_mask.append_value(false);
1680
19
                }
1681
1682
                // if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag
1683
28
                if i < streamed_indices_length - 1
1684
21
                    && streamed_idx != streamed_indices.value(i + 1)
1685
7
                {
1686
7
                    seen_as_true = false;
1687
21
                }
1688
            }
1689
7
            Some((corrected_mask.finish(), filter_matched_indices))
1690
        }
1691
        // LeftAnti semantics: return true if for every x in the collection the join matching filter is false.
1692
        // `filter_matched_indices` needs to be set once per streaming index
1693
        // to prevent duplicates in the output
1694
        JoinType::LeftAnti => {
1695
            // have we seen a filter match for a streaming index before
1696
22
            for i in 0..
streamed_indices_length6
{
1697
22
                let streamed_idx = streamed_indices.value(i);
1698
22
                if mask.value(i)
1699
12
                    && !seen_as_true
1700
8
                    && !matched_indices.contains(&streamed_idx)
1701
8
                {
1702
8
                    seen_as_true = true;
1703
8
                    filter_matched_indices.push(streamed_idx);
1704
14
                }
1705
1706
                // Reset `seen_as_true` flag and calculate mask for the current streaming index
1707
                // - if within the batch it switched to next streaming index(e.g. from 0 to 1, or from 1 to 2)
1708
                // - if it is at the end of the all buffered batches for the given streaming index, 0 index comes last
1709
22
                if (i < streamed_indices_length - 1
1710
16
                    && streamed_idx != streamed_indices.value(i + 1))
1711
16
                    || (i == streamed_indices_length - 1
1712
6
                        && *scanning_buffered_offset == 0)
1713
                {
1714
12
                    corrected_mask.append_value(
1715
12
                        !matched_indices.contains(&streamed_idx) && !seen_as_true,
1716
                    );
1717
12
                    seen_as_true = false;
1718
10
                } else {
1719
10
                    corrected_mask.append_value(false);
1720
10
                }
1721
            }
1722
1723
6
            Some((corrected_mask.finish(), filter_matched_indices))
1724
        }
1725
0
        _ => None,
1726
    }
1727
13
}
1728
1729
/// Buffered data contains all buffered batches with one unique join key
1730
#[derive(Debug, Default)]
1731
struct BufferedData {
1732
    /// Buffered batches with the same key
1733
    pub batches: VecDeque<BufferedBatch>,
1734
    /// current scanning batch index used in join_partial()
1735
    pub scanning_batch_idx: usize,
1736
    /// current scanning offset used in join_partial()
1737
    pub scanning_offset: usize,
1738
}
1739
1740
impl BufferedData {
1741
1.71k
    pub fn head_batch(&self) -> &BufferedBatch {
1742
1.71k
        self.batches.front().unwrap()
1743
1.71k
    }
1744
1745
1.80k
    pub fn tail_batch(&self) -> &BufferedBatch {
1746
1.80k
        self.batches.back().unwrap()
1747
1.80k
    }
1748
1749
249
    pub fn tail_batch_mut(&mut self) -> &mut BufferedBatch {
1750
249
        self.batches.back_mut().unwrap()
1751
249
    }
1752
1753
488
    pub fn has_buffered_rows(&self) -> bool {
1754
488
        self.batches.iter().any(|batch| 
!batch.range.is_empty()476
)
1755
488
    }
1756
1757
569
    pub fn scanning_reset(&mut self) {
1758
569
        self.scanning_batch_idx = 0;
1759
569
        self.scanning_offset = 0;
1760
569
    }
1761
1762
607
    pub fn scanning_advance(&mut self) {
1763
607
        self.scanning_offset += 1;
1764
1.00k
        while !self.scanning_finished() && 
self.scanning_batch_finished()797
{
1765
399
            self.scanning_batch_idx += 1;
1766
399
            self.scanning_offset = 0;
1767
399
        }
1768
607
    }
1769
1770
1.40k
    pub fn scanning_batch(&self) -> &BufferedBatch {
1771
1.40k
        &self.batches[self.scanning_batch_idx]
1772
1.40k
    }
1773
1774
7
    pub fn scanning_batch_mut(&mut self) -> &mut BufferedBatch {
1775
7
        &mut self.batches[self.scanning_batch_idx]
1776
7
    }
1777
1778
607
    pub fn scanning_idx(&self) -> usize {
1779
607
        self.scanning_batch().range.start + self.scanning_offset
1780
607
    }
1781
1782
797
    pub fn scanning_batch_finished(&self) -> bool {
1783
797
        self.scanning_offset == self.scanning_batch().range.len()
1784
797
    }
1785
1786
3.36k
    pub fn scanning_finished(&self) -> bool {
1787
3.36k
        self.scanning_batch_idx == self.batches.len()
1788
3.36k
    }
1789
1790
396
    pub fn scanning_finish(&mut self) {
1791
396
        self.scanning_batch_idx = self.batches.len();
1792
396
        self.scanning_offset = 0;
1793
396
    }
1794
}
1795
1796
/// Get join array refs of given batch and join columns
1797
260
fn join_arrays(batch: &RecordBatch, on_column: &[PhysicalExprRef]) -> Vec<ArrayRef> {
1798
260
    on_column
1799
260
        .iter()
1800
270
        .map(|c| {
1801
270
            let num_rows = batch.num_rows();
1802
270
            let c = c.evaluate(batch).unwrap();
1803
270
            c.into_array(num_rows).unwrap()
1804
270
        })
1805
260
        .collect()
1806
260
}
1807
1808
/// Get comparison result of two rows of join arrays
1809
476
fn compare_join_arrays(
1810
476
    left_arrays: &[ArrayRef],
1811
476
    left: usize,
1812
476
    right_arrays: &[ArrayRef],
1813
476
    right: usize,
1814
476
    sort_options: &[SortOptions],
1815
476
    null_equals_null: bool,
1816
476
) -> Result<Ordering> {
1817
476
    let mut res = Ordering::Equal;
1818
494
    for ((left_array, right_array), sort_options) in
1819
476
        left_arrays.iter().zip(right_arrays).zip(sort_options)
1820
    {
1821
        macro_rules! compare_value {
1822
            ($T:ty) => {{
1823
                let left_array = left_array.as_any().downcast_ref::<$T>().unwrap();
1824
                let right_array = right_array.as_any().downcast_ref::<$T>().unwrap();
1825
                match (left_array.is_null(left), right_array.is_null(right)) {
1826
                    (false, false) => {
1827
                        let left_value = &left_array.value(left);
1828
                        let right_value = &right_array.value(right);
1829
                        res = left_value.partial_cmp(right_value).unwrap();
1830
                        if sort_options.descending {
1831
                            res = res.reverse();
1832
                        }
1833
                    }
1834
                    (true, false) => {
1835
                        res = if sort_options.nulls_first {
1836
                            Ordering::Less
1837
                        } else {
1838
                            Ordering::Greater
1839
                        };
1840
                    }
1841
                    (false, true) => {
1842
                        res = if sort_options.nulls_first {
1843
                            Ordering::Greater
1844
                        } else {
1845
                            Ordering::Less
1846
                        };
1847
                    }
1848
                    _ => {
1849
                        res = if null_equals_null {
1850
                            Ordering::Equal
1851
                        } else {
1852
                            Ordering::Less
1853
                        };
1854
                    }
1855
                }
1856
            }};
1857
        }
1858
1859
494
        match left_array.data_type() {
1860
0
            DataType::Null => {}
1861
0
            DataType::Boolean => compare_value!(BooleanArray),
1862
0
            DataType::Int8 => compare_value!(Int8Array),
1863
0
            DataType::Int16 => compare_value!(Int16Array),
1864
485
            DataType::Int32 => compare_value!(Int32Array),
1865
0
            DataType::Int64 => compare_value!(Int64Array),
1866
0
            DataType::UInt8 => compare_value!(UInt8Array),
1867
0
            DataType::UInt16 => compare_value!(UInt16Array),
1868
0
            DataType::UInt32 => compare_value!(UInt32Array),
1869
0
            DataType::UInt64 => compare_value!(UInt64Array),
1870
0
            DataType::Float32 => compare_value!(Float32Array),
1871
0
            DataType::Float64 => compare_value!(Float64Array),
1872
0
            DataType::Utf8 => compare_value!(StringArray),
1873
0
            DataType::LargeUtf8 => compare_value!(LargeStringArray),
1874
0
            DataType::Decimal128(..) => compare_value!(Decimal128Array),
1875
0
            DataType::Timestamp(time_unit, None) => match time_unit {
1876
0
                TimeUnit::Second => compare_value!(TimestampSecondArray),
1877
0
                TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray),
1878
0
                TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray),
1879
0
                TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray),
1880
            },
1881
4
            DataType::Date32 => compare_value!(Date32Array),
1882
5
            DataType::Date64 => compare_value!(Date64Array),
1883
0
            dt => {
1884
0
                return not_impl_err!(
1885
0
                    "Unsupported data type in sort merge join comparator: {}",
1886
0
                    dt
1887
0
                );
1888
            }
1889
        }
1890
494
        if !res.is_eq() {
1891
188
            break;
1892
306
        }
1893
    }
1894
476
    Ok(res)
1895
476
}
1896
1897
/// A faster version of compare_join_arrays() that only output whether
1898
/// the given two rows are equal
1899
249
fn is_join_arrays_equal(
1900
249
    left_arrays: &[ArrayRef],
1901
249
    left: usize,
1902
249
    right_arrays: &[ArrayRef],
1903
249
    right: usize,
1904
249
) -> Result<bool> {
1905
249
    let mut is_equal = true;
1906
252
    for (left_array, right_array) in 
left_arrays.iter().zip(right_arrays)249
{
1907
        macro_rules! compare_value {
1908
            ($T:ty) => {{
1909
                match (left_array.is_null(left), right_array.is_null(right)) {
1910
                    (false, false) => {
1911
                        let left_array =
1912
                            left_array.as_any().downcast_ref::<$T>().unwrap();
1913
                        let right_array =
1914
                            right_array.as_any().downcast_ref::<$T>().unwrap();
1915
                        if left_array.value(left) != right_array.value(right) {
1916
                            is_equal = false;
1917
                        }
1918
                    }
1919
                    (true, false) => is_equal = false,
1920
                    (false, true) => is_equal = false,
1921
                    _ => {}
1922
                }
1923
            }};
1924
        }
1925
1926
252
        match left_array.data_type() {
1927
0
            DataType::Null => {}
1928
0
            DataType::Boolean => compare_value!(BooleanArray),
1929
0
            DataType::Int8 => compare_value!(Int8Array),
1930
0
            DataType::Int16 => compare_value!(Int16Array),
1931
248
            DataType::Int32 => compare_value!(Int32Array),
1932
0
            DataType::Int64 => compare_value!(Int64Array),
1933
0
            DataType::UInt8 => compare_value!(UInt8Array),
1934
0
            DataType::UInt16 => compare_value!(UInt16Array),
1935
0
            DataType::UInt32 => compare_value!(UInt32Array),
1936
0
            DataType::UInt64 => compare_value!(UInt64Array),
1937
0
            DataType::Float32 => compare_value!(Float32Array),
1938
0
            DataType::Float64 => compare_value!(Float64Array),
1939
0
            DataType::Utf8 => compare_value!(StringArray),
1940
0
            DataType::LargeUtf8 => compare_value!(LargeStringArray),
1941
0
            DataType::Decimal128(..) => compare_value!(Decimal128Array),
1942
0
            DataType::Timestamp(time_unit, None) => match time_unit {
1943
0
                TimeUnit::Second => compare_value!(TimestampSecondArray),
1944
0
                TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray),
1945
0
                TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray),
1946
0
                TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray),
1947
            },
1948
2
            DataType::Date32 => compare_value!(Date32Array),
1949
2
            DataType::Date64 => compare_value!(Date64Array),
1950
0
            dt => {
1951
0
                return not_impl_err!(
1952
0
                    "Unsupported data type in sort merge join comparator: {}",
1953
0
                    dt
1954
0
                );
1955
            }
1956
        }
1957
252
        if !is_equal {
1958
144
            return Ok(false);
1959
108
        }
1960
    }
1961
105
    Ok(true)
1962
249
}
1963
1964
#[cfg(test)]
1965
mod tests {
1966
    use std::sync::Arc;
1967
1968
    use arrow::array::{Date32Array, Date64Array, Int32Array};
1969
    use arrow::compute::SortOptions;
1970
    use arrow::datatypes::{DataType, Field, Schema};
1971
    use arrow::record_batch::RecordBatch;
1972
    use arrow_array::{BooleanArray, UInt64Array};
1973
    use hashbrown::HashSet;
1974
1975
    use datafusion_common::JoinType::{LeftAnti, LeftSemi};
1976
    use datafusion_common::{
1977
        assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result,
1978
    };
1979
    use datafusion_execution::config::SessionConfig;
1980
    use datafusion_execution::disk_manager::DiskManagerConfig;
1981
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
1982
    use datafusion_execution::TaskContext;
1983
1984
    use crate::expressions::Column;
1985
    use crate::joins::sort_merge_join::get_filtered_join_mask;
1986
    use crate::joins::utils::JoinOn;
1987
    use crate::joins::SortMergeJoinExec;
1988
    use crate::memory::MemoryExec;
1989
    use crate::test::build_table_i32;
1990
    use crate::{common, ExecutionPlan};
1991
1992
28
    fn build_table(
1993
28
        a: (&str, &Vec<i32>),
1994
28
        b: (&str, &Vec<i32>),
1995
28
        c: (&str, &Vec<i32>),
1996
28
    ) -> Arc<dyn ExecutionPlan> {
1997
28
        let batch = build_table_i32(a, b, c);
1998
28
        let schema = batch.schema();
1999
28
        Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
2000
28
    }
2001
2002
10
    fn build_table_from_batches(batches: Vec<RecordBatch>) -> Arc<dyn ExecutionPlan> {
2003
10
        let schema = batches.first().unwrap().schema();
2004
10
        Arc::new(MemoryExec::try_new(&[batches], schema, None).unwrap())
2005
10
    }
2006
2007
2
    fn build_date_table(
2008
2
        a: (&str, &Vec<i32>),
2009
2
        b: (&str, &Vec<i32>),
2010
2
        c: (&str, &Vec<i32>),
2011
2
    ) -> Arc<dyn ExecutionPlan> {
2012
2
        let schema = Schema::new(vec![
2013
2
            Field::new(a.0, DataType::Date32, false),
2014
2
            Field::new(b.0, DataType::Date32, false),
2015
2
            Field::new(c.0, DataType::Date32, false),
2016
2
        ]);
2017
2
2018
2
        let batch = RecordBatch::try_new(
2019
2
            Arc::new(schema),
2020
2
            vec![
2021
2
                Arc::new(Date32Array::from(a.1.clone())),
2022
2
                Arc::new(Date32Array::from(b.1.clone())),
2023
2
                Arc::new(Date32Array::from(c.1.clone())),
2024
2
            ],
2025
2
        )
2026
2
        .unwrap();
2027
2
2028
2
        let schema = batch.schema();
2029
2
        Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
2030
2
    }
2031
2032
2
    fn build_date64_table(
2033
2
        a: (&str, &Vec<i64>),
2034
2
        b: (&str, &Vec<i64>),
2035
2
        c: (&str, &Vec<i64>),
2036
2
    ) -> Arc<dyn ExecutionPlan> {
2037
2
        let schema = Schema::new(vec![
2038
2
            Field::new(a.0, DataType::Date64, false),
2039
2
            Field::new(b.0, DataType::Date64, false),
2040
2
            Field::new(c.0, DataType::Date64, false),
2041
2
        ]);
2042
2
2043
2
        let batch = RecordBatch::try_new(
2044
2
            Arc::new(schema),
2045
2
            vec![
2046
2
                Arc::new(Date64Array::from(a.1.clone())),
2047
2
                Arc::new(Date64Array::from(b.1.clone())),
2048
2
                Arc::new(Date64Array::from(c.1.clone())),
2049
2
            ],
2050
2
        )
2051
2
        .unwrap();
2052
2
2053
2
        let schema = batch.schema();
2054
2
        Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
2055
2
    }
2056
2057
    /// returns a table with 3 columns of i32 in memory
2058
4
    pub fn build_table_i32_nullable(
2059
4
        a: (&str, &Vec<Option<i32>>),
2060
4
        b: (&str, &Vec<Option<i32>>),
2061
4
        c: (&str, &Vec<Option<i32>>),
2062
4
    ) -> Arc<dyn ExecutionPlan> {
2063
4
        let schema = Arc::new(Schema::new(vec![
2064
4
            Field::new(a.0, DataType::Int32, true),
2065
4
            Field::new(b.0, DataType::Int32, true),
2066
4
            Field::new(c.0, DataType::Int32, true),
2067
4
        ]));
2068
4
        let batch = RecordBatch::try_new(
2069
4
            Arc::clone(&schema),
2070
4
            vec![
2071
4
                Arc::new(Int32Array::from(a.1.clone())),
2072
4
                Arc::new(Int32Array::from(b.1.clone())),
2073
4
                Arc::new(Int32Array::from(c.1.clone())),
2074
4
            ],
2075
4
        )
2076
4
        .unwrap();
2077
4
        Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
2078
4
    }
2079
2080
1
    fn join(
2081
1
        left: Arc<dyn ExecutionPlan>,
2082
1
        right: Arc<dyn ExecutionPlan>,
2083
1
        on: JoinOn,
2084
1
        join_type: JoinType,
2085
1
    ) -> Result<SortMergeJoinExec> {
2086
1
        let sort_options = vec![SortOptions::default(); on.len()];
2087
1
        SortMergeJoinExec::try_new(left, right, on, None, join_type, sort_options, false)
2088
1
    }
2089
2090
78
    fn join_with_options(
2091
78
        left: Arc<dyn ExecutionPlan>,
2092
78
        right: Arc<dyn ExecutionPlan>,
2093
78
        on: JoinOn,
2094
78
        join_type: JoinType,
2095
78
        sort_options: Vec<SortOptions>,
2096
78
        null_equals_null: bool,
2097
78
    ) -> Result<SortMergeJoinExec> {
2098
78
        SortMergeJoinExec::try_new(
2099
78
            left,
2100
78
            right,
2101
78
            on,
2102
78
            None,
2103
78
            join_type,
2104
78
            sort_options,
2105
78
            null_equals_null,
2106
78
        )
2107
78
    }
2108
2109
17
    async fn join_collect(
2110
17
        left: Arc<dyn ExecutionPlan>,
2111
17
        right: Arc<dyn ExecutionPlan>,
2112
17
        on: JoinOn,
2113
17
        join_type: JoinType,
2114
17
    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
2115
17
        let sort_options = vec![SortOptions::default(); on.len()];
2116
17
        join_collect_with_options(left, right, on, join_type, sort_options, false).
await0
2117
17
    }
2118
2119
18
    async fn join_collect_with_options(
2120
18
        left: Arc<dyn ExecutionPlan>,
2121
18
        right: Arc<dyn ExecutionPlan>,
2122
18
        on: JoinOn,
2123
18
        join_type: JoinType,
2124
18
        sort_options: Vec<SortOptions>,
2125
18
        null_equals_null: bool,
2126
18
    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
2127
18
        let task_ctx = Arc::new(TaskContext::default());
2128
18
        let join = join_with_options(
2129
18
            left,
2130
18
            right,
2131
18
            on,
2132
18
            join_type,
2133
18
            sort_options,
2134
18
            null_equals_null,
2135
18
        )
?0
;
2136
18
        let columns = columns(&join.schema());
2137
2138
18
        let stream = join.execute(0, task_ctx)
?0
;
2139
18
        let batches = common::collect(stream).
await0
?0
;
2140
18
        Ok((columns, batches))
2141
18
    }
2142
2143
1
    async fn join_collect_batch_size_equals_two(
2144
1
        left: Arc<dyn ExecutionPlan>,
2145
1
        right: Arc<dyn ExecutionPlan>,
2146
1
        on: JoinOn,
2147
1
        join_type: JoinType,
2148
1
    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
2149
1
        let task_ctx = TaskContext::default()
2150
1
            .with_session_config(SessionConfig::new().with_batch_size(2));
2151
1
        let task_ctx = Arc::new(task_ctx);
2152
1
        let join = join(left, right, on, join_type)
?0
;
2153
1
        let columns = columns(&join.schema());
2154
2155
1
        let stream = join.execute(0, task_ctx)
?0
;
2156
1
        let batches = common::collect(stream).
await0
?0
;
2157
1
        Ok((columns, batches))
2158
1
    }
2159
2160
    #[tokio::test]
2161
1
    async fn join_inner_one() -> Result<()> {
2162
1
        let left = build_table(
2163
1
            ("a1", &vec![1, 2, 3]),
2164
1
            ("b1", &vec![4, 5, 5]), // this has a repetition
2165
1
            ("c1", &vec![7, 8, 9]),
2166
1
        );
2167
1
        let right = build_table(
2168
1
            ("a2", &vec![10, 20, 30]),
2169
1
            ("b1", &vec![4, 5, 6]),
2170
1
            ("c2", &vec![70, 80, 90]),
2171
1
        );
2172
1
2173
1
        let on = vec![(
2174
1
            Arc::new(Column::new_with_schema("b1", &left.schema())
?0
) as _,
2175
1
            Arc::new(Column::new_with_schema("b1", &right.schema())
?0
) as _,
2176
1
        )];
2177
1
2178
1
        let (_, batches) = join_collect(left, right, on, JoinType::Inner).
await0
?0
;
2179
1
2180
1
        let expected = [
2181
1
            "+----+----+----+----+----+----+",
2182
1
            "| a1 | b1 | c1 | a2 | b1 | c2 |",
2183
1
            "+----+----+----+----+----+----+",
2184
1
            "| 1  | 4  | 7  | 10 | 4  | 70 |",
2185
1
            "| 2  | 5  | 8  | 20 | 5  | 80 |",
2186
1
            "| 3  | 5  | 9  | 20 | 5  | 80 |",
2187
1
            "+----+----+----+----+----+----+",
2188
1
        ];
2189
1
        // The output order is important as SMJ preserves sortedness
2190
1
        assert_batches_eq!(expected, &batches);
2191
1
        Ok(())
2192
1
    }
2193
2194
    #[tokio::test]
2195
1
    async fn join_inner_two() -> Result<()> {
2196
1
        let left = build_table(
2197
1
            ("a1", &vec![1, 2, 2]),
2198
1
            ("b2", &vec![1, 2, 2]),
2199
1
            ("c1", &vec![7, 8, 9]),
2200
1
        );
2201
1
        let right = build_table(
2202
1
            ("a1", &vec![1, 2, 3]),
2203
1
            ("b2", &vec![1, 2, 2]),
2204
1
            ("c2", &vec![70, 80, 90]),
2205
1
        );
2206
1
        let on = vec![
2207
1
            (
2208
1
                Arc::new(Column::new_with_schema("a1", &left.schema())
?0
) as _,
2209
1
                Arc::new(Column::new_with_schema("a1", &right.schema())
?0
) as _,
2210
1
            ),
2211
1
            (
2212
1
                Arc::new(Column::new_with_schema("b2", &left.schema())
?0
) as _,
2213
1
                Arc::new(Column::new_with_schema("b2", &right.schema())
?0
) as _,
2214
1
            ),
2215
1
        ];
2216
1
2217
1
        let (_columns, batches) = join_collect(left, right, on, JoinType::Inner).
await0
?0
;
2218
1
        let expected = [
2219
1
            "+----+----+----+----+----+----+",
2220
1
            "| a1 | b2 | c1 | a1 | b2 | c2 |",
2221
1
            "+----+----+----+----+----+----+",
2222
1
            "| 1  | 1  | 7  | 1  | 1  | 70 |",
2223
1
            "| 2  | 2  | 8  | 2  | 2  | 80 |",
2224
1
            "| 2  | 2  | 9  | 2  | 2  | 80 |",
2225
1
            "+----+----+----+----+----+----+",
2226
1
        ];
2227
1
        // The output order is important as SMJ preserves sortedness
2228
1
        assert_batches_eq!(expected, &batches);
2229
1
        Ok(())
2230
1
    }
2231
2232
    #[tokio::test]
2233
1
    async fn join_inner_two_two() -> Result<()> {
2234
1
        let left = build_table(
2235
1
            ("a1", &vec![1, 1, 2]),
2236
1
            ("b2", &vec![1, 1, 2]),
2237
1
            ("c1", &vec![7, 8, 9]),
2238
1
        );
2239
1
        let right = build_table(
2240
1
            ("a1", &vec![1, 1, 3]),
2241
1
            ("b2", &vec![1, 1, 2]),
2242
1
            ("c2", &vec![70, 80, 90]),
2243
1
        );
2244
1
        let on = vec![
2245
1
            (
2246
1
                Arc::new(Column::new_with_schema("a1", &left.schema())
?0
) as _,
2247
1
                Arc::new(Column::new_with_schema("a1", &right.schema())
?0
) as _,
2248
1
            ),
2249
1
            (
2250
1
                Arc::new(Column::new_with_schema("b2", &left.schema())
?0
) as _,
2251
1
                Arc::new(Column::new_with_schema("b2", &right.schema())
?0
) as _,
2252
1
            ),
2253
1
        ];
2254
1
2255
1
        let (_columns, batches) = join_collect(left, right, on, JoinType::Inner).
await0
?0
;
2256
1
        let expected = [
2257
1
            "+----+----+----+----+----+----+",
2258
1
            "| a1 | b2 | c1 | a1 | b2 | c2 |",
2259
1
            "+----+----+----+----+----+----+",
2260
1
            "| 1  | 1  | 7  | 1  | 1  | 70 |",
2261
1
            "| 1  | 1  | 7  | 1  | 1  | 80 |",
2262
1
            "| 1  | 1  | 8  | 1  | 1  | 70 |",
2263
1
            "| 1  | 1  | 8  | 1  | 1  | 80 |",
2264
1
            "+----+----+----+----+----+----+",
2265
1
        ];
2266
1
        // The output order is important as SMJ preserves sortedness
2267
1
        assert_batches_eq!(expected, &batches);
2268
1
        Ok(())
2269
1
    }
2270
2271
    #[tokio::test]
2272
1
    async fn join_inner_with_nulls() -> Result<()> {
2273
1
        let left = build_table_i32_nullable(
2274
1
            ("a1", &vec![Some(1), Some(1), Some(2), Some(2)]),
2275
1
            ("b2", &vec![None, Some(1), Some(2), Some(2)]), // null in key field
2276
1
            ("c1", &vec![Some(1), None, Some(8), Some(9)]), // null in non-key field
2277
1
        );
2278
1
        let right = build_table_i32_nullable(
2279
1
            ("a1", &vec![Some(1), Some(1), Some(2), Some(3)]),
2280
1
            ("b2", &vec![None, Some(1), Some(2), Some(2)]),
2281
1
            ("c2", &vec![Some(10), Some(70), Some(80), Some(90)]),
2282
1
        );
2283
1
        let on = vec![
2284
1
            (
2285
1
                Arc::new(Column::new_with_schema("a1", &left.schema())
?0
) as _,
2286
1
                Arc::new(Column::new_with_schema("a1", &right.schema())
?0
) as _,
2287
1
            ),
2288
1
            (
2289
1
                Arc::new(Column::new_with_schema("b2", &left.schema())
?0
) as _,
2290
1
                Arc::new(Column::new_with_schema("b2", &right.schema())
?0
) as _,
2291
1
            ),
2292
1
        ];
2293
1
2294
1
        let (_, batches) = join_collect(left, right, on, JoinType::Inner).
await0
?0
;
2295
1
        let expected = [
2296
1
            "+----+----+----+----+----+----+",
2297
1
            "| a1 | b2 | c1 | a1 | b2 | c2 |",
2298
1
            "+----+----+----+----+----+----+",
2299
1
            "| 1  | 1  |    | 1  | 1  | 70 |",
2300
1
            "| 2  | 2  | 8  | 2  | 2  | 80 |",
2301
1
            "| 2  | 2  | 9  | 2  | 2  | 80 |",
2302
1
            "+----+----+----+----+----+----+",
2303
1
        ];
2304
1
        // The output order is important as SMJ preserves sortedness
2305
1
        assert_batches_eq!(expected, &batches);
2306
1
        Ok(())
2307
1
    }
2308
2309
    #[tokio::test]
2310
1
    async fn join_inner_with_nulls_with_options() -> Result<()> {
2311
1
        let left = build_table_i32_nullable(
2312
1
            ("a1", &vec![Some(2), Some(2), Some(1), Some(1)]),
2313
1
            ("b2", &vec![Some(2), Some(2), Some(1), None]), // null in key field
2314
1
            ("c1", &vec![Some(9), Some(8), None, Some(1)]), // null in non-key field
2315
1
        );
2316
1
        let right = build_table_i32_nullable(
2317
1
            ("a1", &vec![Some(3), Some(2), Some(1), Some(1)]),
2318
1
            ("b2", &vec![Some(2), Some(2), Some(1), None]),
2319
1
            ("c2", &vec![Some(90), Some(80), Some(70), Some(10)]),
2320
1
        );
2321
1
        let on = vec![
2322
1
            (
2323
1
                Arc::new(Column::new_with_schema("a1", &left.schema())
?0
) as _,
2324
1
                Arc::new(Column::new_with_schema("a1", &right.schema())
?0
) as _,
2325
1
            ),
2326
1
            (
2327
1
                Arc::new(Column::new_with_schema("b2", &left.schema())
?0
) as _,
2328
1
                Arc::new(Column::new_with_schema("b2", &right.schema())
?0
) as _,
2329
1
            ),
2330
1
        ];
2331
1
        let (_, batches) = join_collect_with_options(
2332
1
            left,
2333
1
            right,
2334
1
            on,
2335
1
            JoinType::Inner,
2336
1
            vec![
2337
1
                SortOptions {
2338
1
                    descending: true,
2339
1
                    nulls_first: false,
2340
1
                };
2341
1
                2
2342
1
            ],
2343
1
            true,
2344
1
        )
2345
1
        .
await0
?0
;
2346
1
        let expected = [
2347
1
            "+----+----+----+----+----+----+",
2348
1
            "| a1 | b2 | c1 | a1 | b2 | c2 |",
2349
1
            "+----+----+----+----+----+----+",
2350
1
            "| 2  | 2  | 9  | 2  | 2  | 80 |",
2351
1
            "| 2  | 2  | 8  | 2  | 2  | 80 |",
2352
1
            "| 1  | 1  |    | 1  | 1  | 70 |",
2353
1
            "| 1  |    | 1  | 1  |    | 10 |",
2354
1
            "+----+----+----+----+----+----+",
2355
1
        ];
2356
1
        // The output order is important as SMJ preserves sortedness
2357
1
        assert_batches_eq!(expected, &batches);
2358
1
        Ok(())
2359
1
    }
2360
2361
    #[tokio::test]
2362
1
    async fn join_inner_output_two_batches() -> Result<()> {
2363
1
        let left = build_table(
2364
1
            ("a1", &vec![1, 2, 2]),
2365
1
            ("b2", &vec![1, 2, 2]),
2366
1
            ("c1", &vec![7, 8, 9]),
2367
1
        );
2368
1
        let right = build_table(
2369
1
            ("a1", &vec![1, 2, 3]),
2370
1
            ("b2", &vec![1, 2, 2]),
2371
1
            ("c2", &vec![70, 80, 90]),
2372
1
        );
2373
1
        let on = vec![
2374
1
            (
2375
1
                Arc::new(Column::new_with_schema("a1", &left.schema())
?0
) as _,
2376
1
                Arc::new(Column::new_with_schema("a1", &right.schema())
?0
) as _,
2377
1
            ),
2378
1
            (
2379
1
                Arc::new(Column::new_with_schema("b2", &left.schema())
?0
) as _,
2380
1
                Arc::new(Column::new_with_schema("b2", &right.schema())
?0
) as _,
2381
1
            ),
2382
1
        ];
2383
1
2384
1
        let (_, batches) =
2385
1
            join_collect_batch_size_equals_two(left, right, on, JoinType::Inner).
await0
?0
;
2386
1
        let expected = [
2387
1
            "+----+----+----+----+----+----+",
2388
1
            "| a1 | b2 | c1 | a1 | b2 | c2 |",
2389
1
            "+----+----+----+----+----+----+",
2390
1
            "| 1  | 1  | 7  | 1  | 1  | 70 |",
2391
1
            "| 2  | 2  | 8  | 2  | 2  | 80 |",
2392
1
            "| 2  | 2  | 9  | 2  | 2  | 80 |",
2393
1
            "+----+----+----+----+----+----+",
2394
1
        ];
2395
1
        assert_eq!(batches.len(), 2);
2396
1
        assert_eq!(batches[0].num_rows(), 2);
2397
1
        assert_eq!(batches[1].num_rows(), 1);
2398
1
        // The output order is important as SMJ preserves sortedness
2399
1
        assert_batches_eq!(expected, &batches);
2400
1
        Ok(())
2401
1
    }
2402
2403
    #[tokio::test]
2404
1
    async fn join_left_one() -> Result<()> {
2405
1
        let left = build_table(
2406
1
            ("a1", &vec![1, 2, 3]),
2407
1
            ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
2408
1
            ("c1", &vec![7, 8, 9]),
2409
1
        );
2410
1
        let right = build_table(
2411
1
            ("a2", &vec![10, 20, 30]),
2412
1
            ("b1", &vec![4, 5, 6]),
2413
1
            ("c2", &vec![70, 80, 90]),
2414
1
        );
2415
1
        let on = vec![(
2416
1
            Arc::new(Column::new_with_schema("b1", &left.schema())
?0
) as _,
2417
1
            Arc::new(Column::new_with_schema("b1", &right.schema())
?0
) as _,
2418
1
        )];
2419
1
2420
1
        let (_, batches) = join_collect(left, right, on, JoinType::Left).
await0
?0
;
2421
1
        let expected = [
2422
1
            "+----+----+----+----+----+----+",
2423
1
            "| a1 | b1 | c1 | a2 | b1 | c2 |",
2424
1
            "+----+----+----+----+----+----+",
2425
1
            "| 1  | 4  | 7  | 10 | 4  | 70 |",
2426
1
            "| 2  | 5  | 8  | 20 | 5  | 80 |",
2427
1
            "| 3  | 7  | 9  |    |    |    |",
2428
1
            "+----+----+----+----+----+----+",
2429
1
        ];
2430
1
        // The output order is important as SMJ preserves sortedness
2431
1
        assert_batches_eq!(expected, &batches);
2432
1
        Ok(())
2433
1
    }
2434
2435
    #[tokio::test]
2436
1
    async fn join_right_one() -> Result<()> {
2437
1
        let left = build_table(
2438
1
            ("a1", &vec![1, 2, 3]),
2439
1
            ("b1", &vec![4, 5, 7]),
2440
1
            ("c1", &vec![7, 8, 9]),
2441
1
        );
2442
1
        let right = build_table(
2443
1
            ("a2", &vec![10, 20, 30]),
2444
1
            ("b1", &vec![4, 5, 6]), // 6 does not exist on the left
2445
1
            ("c2", &vec![70, 80, 90]),
2446
1
        );
2447
1
        let on = vec![(
2448
1
            Arc::new(Column::new_with_schema("b1", &left.schema())
?0
) as _,
2449
1
            Arc::new(Column::new_with_schema("b1", &right.schema())
?0
) as _,
2450
1
        )];
2451
1
2452
1
        let (_, batches) = join_collect(left, right, on, JoinType::Right).
await0
?0
;
2453
1
        let expected = [
2454
1
            "+----+----+----+----+----+----+",
2455
1
            "| a1 | b1 | c1 | a2 | b1 | c2 |",
2456
1
            "+----+----+----+----+----+----+",
2457
1
            "| 1  | 4  | 7  | 10 | 4  | 70 |",
2458
1
            "| 2  | 5  | 8  | 20 | 5  | 80 |",
2459
1
            "|    |    |    | 30 | 6  | 90 |",
2460
1
            "+----+----+----+----+----+----+",
2461
1
        ];
2462
1
        // The output order is important as SMJ preserves sortedness
2463
1
        assert_batches_eq!(expected, &batches);
2464
1
        Ok(())
2465
1
    }
2466
2467
    #[tokio::test]
2468
1
    async fn join_full_one() -> Result<()> {
2469
1
        let left = build_table(
2470
1
            ("a1", &vec![1, 2, 3]),
2471
1
            ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
2472
1
            ("c1", &vec![7, 8, 9]),
2473
1
        );
2474
1
        let right = build_table(
2475
1
            ("a2", &vec![10, 20, 30]),
2476
1
            ("b2", &vec![4, 5, 6]),
2477
1
            ("c2", &vec![70, 80, 90]),
2478
1
        );
2479
1
        let on = vec![(
2480
1
            Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _,
2481
1
            Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _,
2482
1
        )];
2483
1
2484
1
        let (_, batches) = join_collect(left, right, on, JoinType::Full).
await0
?0
;
2485
1
        let expected = [
2486
1
            "+----+----+----+----+----+----+",
2487
1
            "| a1 | b1 | c1 | a2 | b2 | c2 |",
2488
1
            "+----+----+----+----+----+----+",
2489
1
            "|    |    |    | 30 | 6  | 90 |",
2490
1
            "| 1  | 4  | 7  | 10 | 4  | 70 |",
2491
1
            "| 2  | 5  | 8  | 20 | 5  | 80 |",
2492
1
            "| 3  | 7  | 9  |    |    |    |",
2493
1
            "+----+----+----+----+----+----+",
2494
1
        ];
2495
1
        assert_batches_sorted_eq!(expected, &batches);
2496
1
        Ok(())
2497
1
    }
2498
2499
    #[tokio::test]
2500
1
    async fn join_anti() -> Result<()> {
2501
1
        let left = build_table(
2502
1
            ("a1", &vec![1, 2, 2, 3, 5]),
2503
1
            ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right
2504
1
            ("c1", &vec![7, 8, 8, 9, 11]),
2505
1
        );
2506
1
        let right = build_table(
2507
1
            ("a2", &vec![10, 20, 30]),
2508
1
            ("b1", &vec![4, 5, 6]),
2509
1
            ("c2", &vec![70, 80, 90]),
2510
1
        );
2511
1
        let on = vec![(
2512
1
            Arc::new(Column::new_with_schema("b1", &left.schema())
?0
) as _,
2513
1
            Arc::new(Column::new_with_schema("b1", &right.schema())
?0
) as _,
2514
1
        )];
2515
1
2516
1
        let (_, batches) = join_collect(left, right, on, JoinType::LeftAnti).
await0
?0
;
2517
1
        let expected = [
2518
1
            "+----+----+----+",
2519
1
            "| a1 | b1 | c1 |",
2520
1
            "+----+----+----+",
2521
1
            "| 3  | 7  | 9  |",
2522
1
            "| 5  | 7  | 11 |",
2523
1
            "+----+----+----+",
2524
1
        ];
2525
1
        // The output order is important as SMJ preserves sortedness
2526
1
        assert_batches_eq!(expected, &batches);
2527
1
        Ok(())
2528
1
    }
2529
2530
    #[tokio::test]
2531
1
    async fn join_semi() -> Result<()> {
2532
1
        let left = build_table(
2533
1
            ("a1", &vec![1, 2, 2, 3]),
2534
1
            ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right
2535
1
            ("c1", &vec![7, 8, 8, 9]),
2536
1
        );
2537
1
        let right = build_table(
2538
1
            ("a2", &vec![10, 20, 30]),
2539
1
            ("b1", &vec![4, 5, 6]), // 5 is double on the right
2540
1
            ("c2", &vec![70, 80, 90]),
2541
1
        );
2542
1
        let on = vec![(
2543
1
            Arc::new(Column::new_with_schema("b1", &left.schema())
?0
) as _,
2544
1
            Arc::new(Column::new_with_schema("b1", &right.schema())
?0
) as _,
2545
1
        )];
2546
1
2547
1
        let (_, batches) = join_collect(left, right, on, JoinType::LeftSemi).
await0
?0
;
2548
1
        let expected = [
2549
1
            "+----+----+----+",
2550
1
            "| a1 | b1 | c1 |",
2551
1
            "+----+----+----+",
2552
1
            "| 1  | 4  | 7  |",
2553
1
            "| 2  | 5  | 8  |",
2554
1
            "| 2  | 5  | 8  |",
2555
1
            "+----+----+----+",
2556
1
        ];
2557
1
        // The output order is important as SMJ preserves sortedness
2558
1
        assert_batches_eq!(expected, &batches);
2559
1
        Ok(())
2560
1
    }
2561
2562
    #[tokio::test]
2563
1
    async fn join_with_duplicated_column_names() -> Result<()> {
2564
1
        let left = build_table(
2565
1
            ("a", &vec![1, 2, 3]),
2566
1
            ("b", &vec![4, 5, 7]),
2567
1
            ("c", &vec![7, 8, 9]),
2568
1
        );
2569
1
        let right = build_table(
2570
1
            ("a", &vec![10, 20, 30]),
2571
1
            ("b", &vec![1, 2, 7]),
2572
1
            ("c", &vec![70, 80, 90]),
2573
1
        );
2574
1
        let on = vec![(
2575
1
            // join on a=b so there are duplicate column names on unjoined columns
2576
1
            Arc::new(Column::new_with_schema("a", &left.schema())
?0
) as _,
2577
1
            Arc::new(Column::new_with_schema("b", &right.schema())
?0
) as _,
2578
1
        )];
2579
1
2580
1
        let (_, batches) = join_collect(left, right, on, JoinType::Inner).
await0
?0
;
2581
1
        let expected = [
2582
1
            "+---+---+---+----+---+----+",
2583
1
            "| a | b | c | a  | b | c  |",
2584
1
            "+---+---+---+----+---+----+",
2585
1
            "| 1 | 4 | 7 | 10 | 1 | 70 |",
2586
1
            "| 2 | 5 | 8 | 20 | 2 | 80 |",
2587
1
            "+---+---+---+----+---+----+",
2588
1
        ];
2589
1
        // The output order is important as SMJ preserves sortedness
2590
1
        assert_batches_eq!(expected, &batches);
2591
1
        Ok(())
2592
1
    }
2593
2594
    #[tokio::test]
2595
1
    async fn join_date32() -> Result<()> {
2596
1
        let left = build_date_table(
2597
1
            ("a1", &vec![1, 2, 3]),
2598
1
            ("b1", &vec![19107, 19108, 19108]), // this has a repetition
2599
1
            ("c1", &vec![7, 8, 9]),
2600
1
        );
2601
1
        let right = build_date_table(
2602
1
            ("a2", &vec![10, 20, 30]),
2603
1
            ("b1", &vec![19107, 19108, 19109]),
2604
1
            ("c2", &vec![70, 80, 90]),
2605
1
        );
2606
1
2607
1
        let on = vec![(
2608
1
            Arc::new(Column::new_with_schema("b1", &left.schema())
?0
) as _,
2609
1
            Arc::new(Column::new_with_schema("b1", &right.schema())
?0
) as _,
2610
1
        )];
2611
1
2612
1
        let (_, batches) = join_collect(left, right, on, JoinType::Inner).
await0
?0
;
2613
1
2614
1
        let expected = ["+------------+------------+------------+------------+------------+------------+",
2615
1
            "| a1         | b1         | c1         | a2         | b1         | c2         |",
2616
1
            "+------------+------------+------------+------------+------------+------------+",
2617
1
            "| 1970-01-02 | 2022-04-25 | 1970-01-08 | 1970-01-11 | 2022-04-25 | 1970-03-12 |",
2618
1
            "| 1970-01-03 | 2022-04-26 | 1970-01-09 | 1970-01-21 | 2022-04-26 | 1970-03-22 |",
2619
1
            "| 1970-01-04 | 2022-04-26 | 1970-01-10 | 1970-01-21 | 2022-04-26 | 1970-03-22 |",
2620
1
            "+------------+------------+------------+------------+------------+------------+"];
2621
1
        // The output order is important as SMJ preserves sortedness
2622
1
        assert_batches_eq!(expected, &batches);
2623
1
        Ok(())
2624
1
    }
2625
2626
    #[tokio::test]
2627
1
    async fn join_date64() -> Result<()> {
2628
1
        let left = build_date64_table(
2629
1
            ("a1", &vec![1, 2, 3]),
2630
1
            ("b1", &vec![1650703441000, 1650903441000, 1650903441000]), // this has a repetition
2631
1
            ("c1", &vec![7, 8, 9]),
2632
1
        );
2633
1
        let right = build_date64_table(
2634
1
            ("a2", &vec![10, 20, 30]),
2635
1
            ("b1", &vec![1650703441000, 1650503441000, 1650903441000]),
2636
1
            ("c2", &vec![70, 80, 90]),
2637
1
        );
2638
1
2639
1
        let on = vec![(
2640
1
            Arc::new(Column::new_with_schema("b1", &left.schema())
?0
) as _,
2641
1
            Arc::new(Column::new_with_schema("b1", &right.schema())
?0
) as _,
2642
1
        )];
2643
1
2644
1
        let (_, batches) = join_collect(left, right, on, JoinType::Inner).
await0
?0
;
2645
1
2646
1
        let expected = ["+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+",
2647
1
            "| a1                      | b1                  | c1                      | a2                      | b1                  | c2                      |",
2648
1
            "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+",
2649
1
            "| 1970-01-01T00:00:00.001 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.007 | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.070 |",
2650
1
            "| 1970-01-01T00:00:00.002 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |",
2651
1
            "| 1970-01-01T00:00:00.003 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |",
2652
1
            "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+"];
2653
1
        // The output order is important as SMJ preserves sortedness
2654
1
        assert_batches_eq!(expected, &batches);
2655
1
        Ok(())
2656
1
    }
2657
2658
    #[tokio::test]
2659
1
    async fn join_left_sort_order() -> Result<()> {
2660
1
        let left = build_table(
2661
1
            ("a1", &vec![0, 1, 2, 3, 4, 5]),
2662
1
            ("b1", &vec![3, 4, 5, 6, 6, 7]),
2663
1
            ("c1", &vec![4, 5, 6, 7, 8, 9]),
2664
1
        );
2665
1
        let right = build_table(
2666
1
            ("a2", &vec![0, 10, 20, 30, 40]),
2667
1
            ("b2", &vec![2, 4, 6, 6, 8]),
2668
1
            ("c2", &vec![50, 60, 70, 80, 90]),
2669
1
        );
2670
1
        let on = vec![(
2671
1
            Arc::new(Column::new_with_schema("b1", &left.schema())
?0
) as _,
2672
1
            Arc::new(Column::new_with_schema("b2", &right.schema())
?0
) as _,
2673
1
        )];
2674
1
2675
1
        let (_, batches) = join_collect(left, right, on, JoinType::Left).
await0
?0
;
2676
1
        let expected = [
2677
1
            "+----+----+----+----+----+----+",
2678
1
            "| a1 | b1 | c1 | a2 | b2 | c2 |",
2679
1
            "+----+----+----+----+----+----+",
2680
1
            "| 0  | 3  | 4  |    |    |    |",
2681
1
            "| 1  | 4  | 5  | 10 | 4  | 60 |",
2682
1
            "| 2  | 5  | 6  |    |    |    |",
2683
1
            "| 3  | 6  | 7  | 20 | 6  | 70 |",
2684
1
            "| 3  | 6  | 7  | 30 | 6  | 80 |",
2685
1
            "| 4  | 6  | 8  | 20 | 6  | 70 |",
2686
1
            "| 4  | 6  | 8  | 30 | 6  | 80 |",
2687
1
            "| 5  | 7  | 9  |    |    |    |",
2688
1
            "+----+----+----+----+----+----+",
2689
1
        ];
2690
1
        assert_batches_eq!(expected, &batches);
2691
1
        Ok(())
2692
1
    }
2693
2694
    #[tokio::test]
2695
1
    async fn join_right_sort_order() -> Result<()> {
2696
1
        let left = build_table(
2697
1
            ("a1", &vec![0, 1, 2, 3]),
2698
1
            ("b1", &vec![3, 4, 5, 7]),
2699
1
            ("c1", &vec![6, 7, 8, 9]),
2700
1
        );
2701
1
        let right = build_table(
2702
1
            ("a2", &vec![0, 10, 20, 30]),
2703
1
            ("b2", &vec![2, 4, 5, 6]),
2704
1
            ("c2", &vec![60, 70, 80, 90]),
2705
1
        );
2706
1
        let on = vec![(
2707
1
            Arc::new(Column::new_with_schema("b1", &left.schema())
?0
) as _,
2708
1
            Arc::new(Column::new_with_schema("b2", &right.schema())
?0
) as _,
2709
1
        )];
2710
1
2711
1
        let (_, batches) = join_collect(left, right, on, JoinType::Right).
await0
?0
;
2712
1
        let expected = [
2713
1
            "+----+----+----+----+----+----+",
2714
1
            "| a1 | b1 | c1 | a2 | b2 | c2 |",
2715
1
            "+----+----+----+----+----+----+",
2716
1
            "|    |    |    | 0  | 2  | 60 |",
2717
1
            "| 1  | 4  | 7  | 10 | 4  | 70 |",
2718
1
            "| 2  | 5  | 8  | 20 | 5  | 80 |",
2719
1
            "|    |    |    | 30 | 6  | 90 |",
2720
1
            "+----+----+----+----+----+----+",
2721
1
        ];
2722
1
        assert_batches_eq!(expected, &batches);
2723
1
        Ok(())
2724
1
    }
2725
2726
    #[tokio::test]
2727
1
    async fn join_left_multiple_batches() -> Result<()> {
2728
1
        let left_batch_1 = build_table_i32(
2729
1
            ("a1", &vec![0, 1, 2]),
2730
1
            ("b1", &vec![3, 4, 5]),
2731
1
            ("c1", &vec![4, 5, 6]),
2732
1
        );
2733
1
        let left_batch_2 = build_table_i32(
2734
1
            ("a1", &vec![3, 4, 5, 6]),
2735
1
            ("b1", &vec![6, 6, 7, 9]),
2736
1
            ("c1", &vec![7, 8, 9, 9]),
2737
1
        );
2738
1
        let right_batch_1 = build_table_i32(
2739
1
            ("a2", &vec![0, 10, 20]),
2740
1
            ("b2", &vec![2, 4, 6]),
2741
1
            ("c2", &vec![50, 60, 70]),
2742
1
        );
2743
1
        let right_batch_2 = build_table_i32(
2744
1
            ("a2", &vec![30, 40]),
2745
1
            ("b2", &vec![6, 8]),
2746
1
            ("c2", &vec![80, 90]),
2747
1
        );
2748
1
        let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
2749
1
        let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
2750
1
        let on = vec![(
2751
1
            Arc::new(Column::new_with_schema("b1", &left.schema())
?0
) as _,
2752
1
            Arc::new(Column::new_with_schema("b2", &right.schema())
?0
) as _,
2753
1
        )];
2754
1
2755
1
        let (_, batches) = join_collect(left, right, on, JoinType::Left).
await0
?0
;
2756
1
        let expected = vec![
2757
1
            "+----+----+----+----+----+----+",
2758
1
            "| a1 | b1 | c1 | a2 | b2 | c2 |",
2759
1
            "+----+----+----+----+----+----+",
2760
1
            "| 0  | 3  | 4  |    |    |    |",
2761
1
            "| 1  | 4  | 5  | 10 | 4  | 60 |",
2762
1
            "| 2  | 5  | 6  |    |    |    |",
2763
1
            "| 3  | 6  | 7  | 20 | 6  | 70 |",
2764
1
            "| 3  | 6  | 7  | 30 | 6  | 80 |",
2765
1
            "| 4  | 6  | 8  | 20 | 6  | 70 |",
2766
1
            "| 4  | 6  | 8  | 30 | 6  | 80 |",
2767
1
            "| 5  | 7  | 9  |    |    |    |",
2768
1
            "| 6  | 9  | 9  |    |    |    |",
2769
1
            "+----+----+----+----+----+----+",
2770
1
        ];
2771
1
        assert_batches_eq!(expected, &batches);
2772
1
        Ok(())
2773
1
    }
2774
2775
    #[tokio::test]
2776
1
    async fn join_right_multiple_batches() -> Result<()> {
2777
1
        let right_batch_1 = build_table_i32(
2778
1
            ("a2", &vec![0, 1, 2]),
2779
1
            ("b2", &vec![3, 4, 5]),
2780
1
            ("c2", &vec![4, 5, 6]),
2781
1
        );
2782
1
        let right_batch_2 = build_table_i32(
2783
1
            ("a2", &vec![3, 4, 5, 6]),
2784
1
            ("b2", &vec![6, 6, 7, 9]),
2785
1
            ("c2", &vec![7, 8, 9, 9]),
2786
1
        );
2787
1
        let left_batch_1 = build_table_i32(
2788
1
            ("a1", &vec![0, 10, 20]),
2789
1
            ("b1", &vec![2, 4, 6]),
2790
1
            ("c1", &vec![50, 60, 70]),
2791
1
        );
2792
1
        let left_batch_2 = build_table_i32(
2793
1
            ("a1", &vec![30, 40]),
2794
1
            ("b1", &vec![6, 8]),
2795
1
            ("c1", &vec![80, 90]),
2796
1
        );
2797
1
        let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
2798
1
        let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
2799
1
        let on = vec![(
2800
1
            Arc::new(Column::new_with_schema("b1", &left.schema())
?0
) as _,
2801
1
            Arc::new(Column::new_with_schema("b2", &right.schema())
?0
) as _,
2802
1
        )];
2803
1
2804
1
        let (_, batches) = join_collect(left, right, on, JoinType::Right).
await0
?0
;
2805
1
        let expected = vec![
2806
1
            "+----+----+----+----+----+----+",
2807
1
            "| a1 | b1 | c1 | a2 | b2 | c2 |",
2808
1
            "+----+----+----+----+----+----+",
2809
1
            "|    |    |    | 0  | 3  | 4  |",
2810
1
            "| 10 | 4  | 60 | 1  | 4  | 5  |",
2811
1
            "|    |    |    | 2  | 5  | 6  |",
2812
1
            "| 20 | 6  | 70 | 3  | 6  | 7  |",
2813
1
            "| 30 | 6  | 80 | 3  | 6  | 7  |",
2814
1
            "| 20 | 6  | 70 | 4  | 6  | 8  |",
2815
1
            "| 30 | 6  | 80 | 4  | 6  | 8  |",
2816
1
            "|    |    |    | 5  | 7  | 9  |",
2817
1
            "|    |    |    | 6  | 9  | 9  |",
2818
1
            "+----+----+----+----+----+----+",
2819
1
        ];
2820
1
        assert_batches_eq!(expected, &batches);
2821
1
        Ok(())
2822
1
    }
2823
2824
    #[tokio::test]
2825
1
    async fn join_full_multiple_batches() -> Result<()> {
2826
1
        let left_batch_1 = build_table_i32(
2827
1
            ("a1", &vec![0, 1, 2]),
2828
1
            ("b1", &vec![3, 4, 5]),
2829
1
            ("c1", &vec![4, 5, 6]),
2830
1
        );
2831
1
        let left_batch_2 = build_table_i32(
2832
1
            ("a1", &vec![3, 4, 5, 6]),
2833
1
            ("b1", &vec![6, 6, 7, 9]),
2834
1
            ("c1", &vec![7, 8, 9, 9]),
2835
1
        );
2836
1
        let right_batch_1 = build_table_i32(
2837
1
            ("a2", &vec![0, 10, 20]),
2838
1
            ("b2", &vec![2, 4, 6]),
2839
1
            ("c2", &vec![50, 60, 70]),
2840
1
        );
2841
1
        let right_batch_2 = build_table_i32(
2842
1
            ("a2", &vec![30, 40]),
2843
1
            ("b2", &vec![6, 8]),
2844
1
            ("c2", &vec![80, 90]),
2845
1
        );
2846
1
        let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
2847
1
        let right = build_table_from_batches(vec![right_batch_1, right_batch_2]);
2848
1
        let on = vec![(
2849
1
            Arc::new(Column::new_with_schema("b1", &left.schema())
?0
) as _,
2850
1
            Arc::new(Column::new_with_schema("b2", &right.schema())
?0
) as _,
2851
1
        )];
2852
1
2853
1
        let (_, batches) = join_collect(left, right, on, JoinType::Full).
await0
?0
;
2854
1
        let expected = vec![
2855
1
            "+----+----+----+----+----+----+",
2856
1
            "| a1 | b1 | c1 | a2 | b2 | c2 |",
2857
1
            "+----+----+----+----+----+----+",
2858
1
            "|    |    |    | 0  | 2  | 50 |",
2859
1
            "|    |    |    | 40 | 8  | 90 |",
2860
1
            "| 0  | 3  | 4  |    |    |    |",
2861
1
            "| 1  | 4  | 5  | 10 | 4  | 60 |",
2862
1
            "| 2  | 5  | 6  |    |    |    |",
2863
1
            "| 3  | 6  | 7  | 20 | 6  | 70 |",
2864
1
            "| 3  | 6  | 7  | 30 | 6  | 80 |",
2865
1
            "| 4  | 6  | 8  | 20 | 6  | 70 |",
2866
1
            "| 4  | 6  | 8  | 30 | 6  | 80 |",
2867
1
            "| 5  | 7  | 9  |    |    |    |",
2868
1
            "| 6  | 9  | 9  |    |    |    |",
2869
1
            "+----+----+----+----+----+----+",
2870
1
        ];
2871
1
        assert_batches_sorted_eq!(expected, &batches);
2872
1
        Ok(())
2873
1
    }
2874
2875
    #[tokio::test]
2876
1
    async fn overallocation_single_batch_no_spill() -> Result<()> {
2877
1
        let left = build_table(
2878
1
            ("a1", &vec![0, 1, 2, 3, 4, 5]),
2879
1
            ("b1", &vec![1, 2, 3, 4, 5, 6]),
2880
1
            ("c1", &vec![4, 5, 6, 7, 8, 9]),
2881
1
        );
2882
1
        let right = build_table(
2883
1
            ("a2", &vec![0, 10, 20, 30, 40]),
2884
1
            ("b2", &vec![1, 3, 4, 6, 8]),
2885
1
            ("c2", &vec![50, 60, 70, 80, 90]),
2886
1
        );
2887
1
        let on = vec![(
2888
1
            Arc::new(Column::new_with_schema("b1", &left.schema())
?0
) as _,
2889
1
            Arc::new(Column::new_with_schema("b2", &right.schema())
?0
) as _,
2890
1
        )];
2891
1
        let sort_options = vec![SortOptions::default(); on.len()];
2892
1
2893
1
        let join_types = vec![
2894
1
            JoinType::Inner,
2895
1
            JoinType::Left,
2896
1
            JoinType::Right,
2897
1
            JoinType::Full,
2898
1
            JoinType::LeftSemi,
2899
1
            JoinType::LeftAnti,
2900
1
        ];
2901
1
2902
1
        // Disable DiskManager to prevent spilling
2903
1
        let runtime = RuntimeEnvBuilder::new()
2904
1
            .with_memory_limit(100, 1.0)
2905
1
            .with_disk_manager(DiskManagerConfig::Disabled)
2906
1
            .build_arc()
?0
;
2907
1
        let session_config = SessionConfig::default().with_batch_size(50);
2908
1
2909
7
        for 
join_type6
in join_types {
2910
6
            let task_ctx = TaskContext::default()
2911
6
                .with_session_config(session_config.clone())
2912
6
                .with_runtime(Arc::clone(&runtime));
2913
6
            let task_ctx = Arc::new(task_ctx);
2914
1
2915
6
            let join = join_with_options(
2916
6
                Arc::clone(&left),
2917
6
                Arc::clone(&right),
2918
6
                on.clone(),
2919
6
                join_type,
2920
6
                sort_options.clone(),
2921
6
                false,
2922
6
            )
?0
;
2923
1
2924
6
            let stream = join.execute(0, task_ctx)
?0
;
2925
6
            let err = common::collect(stream).
await0
.unwrap_err();
2926
6
2927
6
            assert_contains!(err.to_string(), "Failed to allocate additional");
2928
6
            assert_contains!(err.to_string(), "SMJStream[0]");
2929
6
            assert_contains!(err.to_string(), "Disk spilling disabled");
2930
6
            assert!(join.metrics().is_some());
2931
6
            assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
2932
6
            assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
2933
6
            assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
2934
1
        }
2935
1
2936
1
        Ok(())
2937
1
    }
2938
2939
    #[tokio::test]
2940
1
    async fn overallocation_multi_batch_no_spill() -> Result<()> {
2941
1
        let left_batch_1 = build_table_i32(
2942
1
            ("a1", &vec![0, 1]),
2943
1
            ("b1", &vec![1, 1]),
2944
1
            ("c1", &vec![4, 5]),
2945
1
        );
2946
1
        let left_batch_2 = build_table_i32(
2947
1
            ("a1", &vec![2, 3]),
2948
1
            ("b1", &vec![1, 1]),
2949
1
            ("c1", &vec![6, 7]),
2950
1
        );
2951
1
        let left_batch_3 = build_table_i32(
2952
1
            ("a1", &vec![4, 5]),
2953
1
            ("b1", &vec![1, 1]),
2954
1
            ("c1", &vec![8, 9]),
2955
1
        );
2956
1
        let right_batch_1 = build_table_i32(
2957
1
            ("a2", &vec![0, 10]),
2958
1
            ("b2", &vec![1, 1]),
2959
1
            ("c2", &vec![50, 60]),
2960
1
        );
2961
1
        let right_batch_2 = build_table_i32(
2962
1
            ("a2", &vec![20, 30]),
2963
1
            ("b2", &vec![1, 1]),
2964
1
            ("c2", &vec![70, 80]),
2965
1
        );
2966
1
        let right_batch_3 =
2967
1
            build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", &vec![90]));
2968
1
        let left =
2969
1
            build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]);
2970
1
        let right =
2971
1
            build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]);
2972
1
        let on = vec![(
2973
1
            Arc::new(Column::new_with_schema("b1", &left.schema())
?0
) as _,
2974
1
            Arc::new(Column::new_with_schema("b2", &right.schema())
?0
) as _,
2975
1
        )];
2976
1
        let sort_options = vec![SortOptions::default(); on.len()];
2977
1
2978
1
        let join_types = vec![
2979
1
            JoinType::Inner,
2980
1
            JoinType::Left,
2981
1
            JoinType::Right,
2982
1
            JoinType::Full,
2983
1
            JoinType::LeftSemi,
2984
1
            JoinType::LeftAnti,
2985
1
        ];
2986
1
2987
1
        // Disable DiskManager to prevent spilling
2988
1
        let runtime = RuntimeEnvBuilder::new()
2989
1
            .with_memory_limit(100, 1.0)
2990
1
            .with_disk_manager(DiskManagerConfig::Disabled)
2991
1
            .build_arc()
?0
;
2992
1
        let session_config = SessionConfig::default().with_batch_size(50);
2993
1
2994
7
        for 
join_type6
in join_types {
2995
6
            let task_ctx = TaskContext::default()
2996
6
                .with_session_config(session_config.clone())
2997
6
                .with_runtime(Arc::clone(&runtime));
2998
6
            let task_ctx = Arc::new(task_ctx);
2999
6
            let join = join_with_options(
3000
6
                Arc::clone(&left),
3001
6
                Arc::clone(&right),
3002
6
                on.clone(),
3003
6
                join_type,
3004
6
                sort_options.clone(),
3005
6
                false,
3006
6
            )
?0
;
3007
1
3008
6
            let stream = join.execute(0, task_ctx)
?0
;
3009
6
            let err = common::collect(stream).
await0
.unwrap_err();
3010
6
3011
6
            assert_contains!(err.to_string(), "Failed to allocate additional");
3012
6
            assert_contains!(err.to_string(), "SMJStream[0]");
3013
6
            assert_contains!(err.to_string(), "Disk spilling disabled");
3014
6
            assert!(join.metrics().is_some());
3015
6
            assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
3016
6
            assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
3017
6
            assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
3018
1
        }
3019
1
3020
1
        Ok(())
3021
1
    }
3022
3023
    #[tokio::test]
3024
1
    async fn overallocation_single_batch_spill() -> Result<()> {
3025
1
        let left = build_table(
3026
1
            ("a1", &vec![0, 1, 2, 3, 4, 5]),
3027
1
            ("b1", &vec![1, 2, 3, 4, 5, 6]),
3028
1
            ("c1", &vec![4, 5, 6, 7, 8, 9]),
3029
1
        );
3030
1
        let right = build_table(
3031
1
            ("a2", &vec![0, 10, 20, 30, 40]),
3032
1
            ("b2", &vec![1, 3, 4, 6, 8]),
3033
1
            ("c2", &vec![50, 60, 70, 80, 90]),
3034
1
        );
3035
1
        let on = vec![(
3036
1
            Arc::new(Column::new_with_schema("b1", &left.schema())
?0
) as _,
3037
1
            Arc::new(Column::new_with_schema("b2", &right.schema())
?0
) as _,
3038
1
        )];
3039
1
        let sort_options = vec![SortOptions::default(); on.len()];
3040
1
3041
1
        let join_types = [
3042
1
            JoinType::Inner,
3043
1
            JoinType::Left,
3044
1
            JoinType::Right,
3045
1
            JoinType::Full,
3046
1
            JoinType::LeftSemi,
3047
1
            JoinType::LeftAnti,
3048
1
        ];
3049
1
3050
1
        // Enable DiskManager to allow spilling
3051
1
        let runtime = RuntimeEnvBuilder::new()
3052
1
            .with_memory_limit(100, 1.0)
3053
1
            .with_disk_manager(DiskManagerConfig::NewOs)
3054
1
            .build_arc()
?0
;
3055
1
3056
3
        for 
batch_size2
in [1, 50] {
3057
2
            let session_config = SessionConfig::default().with_batch_size(batch_size);
3058
1
3059
14
            for 
join_type12
in &join_types {
3060
12
                let task_ctx = TaskContext::default()
3061
12
                    .with_session_config(session_config.clone())
3062
12
                    .with_runtime(Arc::clone(&runtime));
3063
12
                let task_ctx = Arc::new(task_ctx);
3064
1
3065
12
                let join = join_with_options(
3066
12
                    Arc::clone(&left),
3067
12
                    Arc::clone(&right),
3068
12
                    on.clone(),
3069
12
                    *join_type,
3070
12
                    sort_options.clone(),
3071
12
                    false,
3072
12
                )
?0
;
3073
1
3074
12
                let stream = join.execute(0, task_ctx)
?0
;
3075
12
                let spilled_join_result = common::collect(stream).
await0
.unwrap();
3076
12
3077
12
                assert!(join.metrics().is_some());
3078
12
                assert!(join.metrics().unwrap().spill_count().unwrap() > 0);
3079
12
                assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0);
3080
12
                assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0);
3081
1
3082
1
                // Run the test with no spill configuration as
3083
12
                let task_ctx_no_spill =
3084
12
                    TaskContext::default().with_session_config(session_config.clone());
3085
12
                let task_ctx_no_spill = Arc::new(task_ctx_no_spill);
3086
1
3087
12
                let join = join_with_options(
3088
12
                    Arc::clone(&left),
3089
12
                    Arc::clone(&right),
3090
12
                    on.clone(),
3091
12
                    *join_type,
3092
12
                    sort_options.clone(),
3093
12
                    false,
3094
12
                )
?0
;
3095
12
                let stream = join.execute(0, task_ctx_no_spill)
?0
;
3096
12
                let no_spilled_join_result = common::collect(stream).
await0
.unwrap();
3097
12
3098
12
                assert!(join.metrics().is_some());
3099
12
                assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
3100
12
                assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
3101
12
                assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
3102
1
                // Compare spilled and non spilled data to check spill logic doesn't corrupt the data
3103
12
                assert_eq!(spilled_join_result, no_spilled_join_result);
3104
1
            }
3105
1
        }
3106
1
3107
1
        Ok(())
3108
1
    }
3109
3110
    #[tokio::test]
3111
1
    async fn overallocation_multi_batch_spill() -> Result<()> {
3112
1
        let left_batch_1 = build_table_i32(
3113
1
            ("a1", &vec![0, 1]),
3114
1
            ("b1", &vec![1, 1]),
3115
1
            ("c1", &vec![4, 5]),
3116
1
        );
3117
1
        let left_batch_2 = build_table_i32(
3118
1
            ("a1", &vec![2, 3]),
3119
1
            ("b1", &vec![1, 1]),
3120
1
            ("c1", &vec![6, 7]),
3121
1
        );
3122
1
        let left_batch_3 = build_table_i32(
3123
1
            ("a1", &vec![4, 5]),
3124
1
            ("b1", &vec![1, 1]),
3125
1
            ("c1", &vec![8, 9]),
3126
1
        );
3127
1
        let right_batch_1 = build_table_i32(
3128
1
            ("a2", &vec![0, 10]),
3129
1
            ("b2", &vec![1, 1]),
3130
1
            ("c2", &vec![50, 60]),
3131
1
        );
3132
1
        let right_batch_2 = build_table_i32(
3133
1
            ("a2", &vec![20, 30]),
3134
1
            ("b2", &vec![1, 1]),
3135
1
            ("c2", &vec![70, 80]),
3136
1
        );
3137
1
        let right_batch_3 =
3138
1
            build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", &vec![90]));
3139
1
        let left =
3140
1
            build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]);
3141
1
        let right =
3142
1
            build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]);
3143
1
        let on = vec![(
3144
1
            Arc::new(Column::new_with_schema("b1", &left.schema())
?0
) as _,
3145
1
            Arc::new(Column::new_with_schema("b2", &right.schema())
?0
) as _,
3146
1
        )];
3147
1
        let sort_options = vec![SortOptions::default(); on.len()];
3148
1
3149
1
        let join_types = [
3150
1
            JoinType::Inner,
3151
1
            JoinType::Left,
3152
1
            JoinType::Right,
3153
1
            JoinType::Full,
3154
1
            JoinType::LeftSemi,
3155
1
            JoinType::LeftAnti,
3156
1
        ];
3157
1
3158
1
        // Enable DiskManager to allow spilling
3159
1
        let runtime = RuntimeEnvBuilder::new()
3160
1
            .with_memory_limit(500, 1.0)
3161
1
            .with_disk_manager(DiskManagerConfig::NewOs)
3162
1
            .build_arc()
?0
;
3163
1
3164
3
        for 
batch_size2
in [1, 50] {
3165
2
            let session_config = SessionConfig::default().with_batch_size(batch_size);
3166
1
3167
14
            for 
join_type12
in &join_types {
3168
12
                let task_ctx = TaskContext::default()
3169
12
                    .with_session_config(session_config.clone())
3170
12
                    .with_runtime(Arc::clone(&runtime));
3171
12
                let task_ctx = Arc::new(task_ctx);
3172
12
                let join = join_with_options(
3173
12
                    Arc::clone(&left),
3174
12
                    Arc::clone(&right),
3175
12
                    on.clone(),
3176
12
                    *join_type,
3177
12
                    sort_options.clone(),
3178
12
                    false,
3179
12
                )
?0
;
3180
1
3181
12
                let stream = join.execute(0, task_ctx)
?0
;
3182
12
                let spilled_join_result = common::collect(stream).
await0
.unwrap();
3183
12
                assert!(join.metrics().is_some());
3184
12
                assert!(join.metrics().unwrap().spill_count().unwrap() > 0);
3185
12
                assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0);
3186
12
                assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0);
3187
1
3188
1
                // Run the test with no spill configuration as
3189
12
                let task_ctx_no_spill =
3190
12
                    TaskContext::default().with_session_config(session_config.clone());
3191
12
                let task_ctx_no_spill = Arc::new(task_ctx_no_spill);
3192
1
3193
12
                let join = join_with_options(
3194
12
                    Arc::clone(&left),
3195
12
                    Arc::clone(&right),
3196
12
                    on.clone(),
3197
12
                    *join_type,
3198
12
                    sort_options.clone(),
3199
12
                    false,
3200
12
                )
?0
;
3201
12
                let stream = join.execute(0, task_ctx_no_spill)
?0
;
3202
12
                let no_spilled_join_result = common::collect(stream).
await0
.unwrap();
3203
12
3204
12
                assert!(join.metrics().is_some());
3205
12
                assert_eq!(join.metrics().unwrap().spill_count(), Some(0));
3206
12
                assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0));
3207
12
                assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0));
3208
1
                // Compare spilled and non spilled data to check spill logic doesn't corrupt the data
3209
12
                assert_eq!(spilled_join_result, no_spilled_join_result);
3210
1
            }
3211
1
        }
3212
1
3213
1
        Ok(())
3214
1
    }
3215
3216
    #[tokio::test]
3217
1
    async fn left_semi_join_filtered_mask() -> Result<()> {
3218
1
        assert_eq!(
3219
1
            get_filtered_join_mask(
3220
1
                LeftSemi,
3221
1
                &UInt64Array::from(vec![0, 0, 1, 1]),
3222
1
                &BooleanArray::from(vec![true, true, false, false]),
3223
1
                &HashSet::new(),
3224
1
                &0,
3225
1
            ),
3226
1
            Some((BooleanArray::from(vec![true, false, false, false]), vec![0]))
3227
1
        );
3228
1
3229
1
        assert_eq!(
3230
1
            get_filtered_join_mask(
3231
1
                LeftSemi,
3232
1
                &UInt64Array::from(vec![0, 1]),
3233
1
                &BooleanArray::from(vec![true, true]),
3234
1
                &HashSet::new(),
3235
1
                &0,
3236
1
            ),
3237
1
            Some((BooleanArray::from(vec![true, true]), vec![0, 1]))
3238
1
        );
3239
1
3240
1
        assert_eq!(
3241
1
            get_filtered_join_mask(
3242
1
                LeftSemi,
3243
1
                &UInt64Array::from(vec![0, 1]),
3244
1
                &BooleanArray::from(vec![false, true]),
3245
1
                &HashSet::new(),
3246
1
                &0,
3247
1
            ),
3248
1
            Some((BooleanArray::from(vec![false, true]), vec![1]))
3249
1
        );
3250
1
3251
1
        assert_eq!(
3252
1
            get_filtered_join_mask(
3253
1
                LeftSemi,
3254
1
                &UInt64Array::from(vec![0, 1]),
3255
1
                &BooleanArray::from(vec![true, false]),
3256
1
                &HashSet::new(),
3257
1
                &0,
3258
1
            ),
3259
1
            Some((BooleanArray::from(vec![true, false]), vec![0]))
3260
1
        );
3261
1
3262
1
        assert_eq!(
3263
1
            get_filtered_join_mask(
3264
1
                LeftSemi,
3265
1
                &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
3266
1
                &BooleanArray::from(vec![false, true, true, true, true, true]),
3267
1
                &HashSet::new(),
3268
1
                &0,
3269
1
            ),
3270
1
            Some((
3271
1
                BooleanArray::from(vec![false, true, false, true, false, false]),
3272
1
                vec![0, 1]
3273
1
            ))
3274
1
        );
3275
1
3276
1
        assert_eq!(
3277
1
            get_filtered_join_mask(
3278
1
                LeftSemi,
3279
1
                &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
3280
1
                &BooleanArray::from(vec![false, false, false, false, false, true]),
3281
1
                &HashSet::new(),
3282
1
                &0,
3283
1
            ),
3284
1
            Some((
3285
1
                BooleanArray::from(vec![false, false, false, false, false, true]),
3286
1
                vec![1]
3287
1
            ))
3288
1
        );
3289
1
3290
1
        assert_eq!(
3291
1
            get_filtered_join_mask(
3292
1
                LeftSemi,
3293
1
                &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
3294
1
                &BooleanArray::from(vec![true, false, false, false, false, true]),
3295
1
                &HashSet::from_iter(vec![1]),
3296
1
                &0,
3297
1
            ),
3298
1
            Some((
3299
1
                BooleanArray::from(vec![true, false, false, false, false, false]),
3300
1
                vec![0]
3301
1
            ))
3302
1
        );
3303
1
3304
1
        Ok(())
3305
1
    }
3306
3307
    #[tokio::test]
3308
1
    async fn left_anti_join_filtered_mask() -> Result<()> {
3309
1
        assert_eq!(
3310
1
            get_filtered_join_mask(
3311
1
                LeftAnti,
3312
1
                &UInt64Array::from(vec![0, 0, 1, 1]),
3313
1
                &BooleanArray::from(vec![true, true, false, false]),
3314
1
                &HashSet::new(),
3315
1
                &0,
3316
1
            ),
3317
1
            Some((BooleanArray::from(vec![false, false, false, true]), vec![0]))
3318
1
        );
3319
1
3320
1
        assert_eq!(
3321
1
            get_filtered_join_mask(
3322
1
                LeftAnti,
3323
1
                &UInt64Array::from(vec![0, 1]),
3324
1
                &BooleanArray::from(vec![true, true]),
3325
1
                &HashSet::new(),
3326
1
                &0,
3327
1
            ),
3328
1
            Some((BooleanArray::from(vec![false, false]), vec![0, 1]))
3329
1
        );
3330
1
3331
1
        assert_eq!(
3332
1
            get_filtered_join_mask(
3333
1
                LeftAnti,
3334
1
                &UInt64Array::from(vec![0, 1]),
3335
1
                &BooleanArray::from(vec![false, true]),
3336
1
                &HashSet::new(),
3337
1
                &0,
3338
1
            ),
3339
1
            Some((BooleanArray::from(vec![true, false]), vec![1]))
3340
1
        );
3341
1
3342
1
        assert_eq!(
3343
1
            get_filtered_join_mask(
3344
1
                LeftAnti,
3345
1
                &UInt64Array::from(vec![0, 1]),
3346
1
                &BooleanArray::from(vec![true, false]),
3347
1
                &HashSet::new(),
3348
1
                &0,
3349
1
            ),
3350
1
            Some((BooleanArray::from(vec![false, true]), vec![0]))
3351
1
        );
3352
1
3353
1
        assert_eq!(
3354
1
            get_filtered_join_mask(
3355
1
                LeftAnti,
3356
1
                &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
3357
1
                &BooleanArray::from(vec![false, true, true, true, true, true]),
3358
1
                &HashSet::new(),
3359
1
                &0,
3360
1
            ),
3361
1
            Some((
3362
1
                BooleanArray::from(vec![false, false, false, false, false, false]),
3363
1
                vec![0, 1]
3364
1
            ))
3365
1
        );
3366
1
3367
1
        assert_eq!(
3368
1
            get_filtered_join_mask(
3369
1
                LeftAnti,
3370
1
                &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
3371
1
                &BooleanArray::from(vec![false, false, false, false, false, true]),
3372
1
                &HashSet::new(),
3373
1
                &0,
3374
1
            ),
3375
1
            Some((
3376
1
                BooleanArray::from(vec![false, false, true, false, false, false]),
3377
1
                vec![1]
3378
1
            ))
3379
1
        );
3380
1
3381
1
        Ok(())
3382
1
    }
3383
3384
    /// Returns the column names on the schema
3385
19
    fn columns(schema: &Schema) -> Vec<String> {
3386
108
        schema.fields().iter().map(|f| f.name().clone()).collect()
3387
19
    }
3388
}