Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/joins/nested_loop_join.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Defines the nested loop join plan, it supports all [`JoinType`].
19
//! The nested loop join can execute in parallel by partitions and it is
20
//! determined by the [`JoinType`].
21
22
use std::any::Any;
23
use std::fmt::Formatter;
24
use std::sync::atomic::{AtomicUsize, Ordering};
25
use std::sync::Arc;
26
use std::task::Poll;
27
28
use super::utils::{asymmetric_join_output_partitioning, need_produce_result_in_final};
29
use crate::coalesce_partitions::CoalescePartitionsExec;
30
use crate::joins::utils::{
31
    adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices,
32
    build_join_schema, check_join_is_valid, estimate_join_statistics,
33
    get_final_indices_from_bit_map, BuildProbeJoinMetrics, ColumnIndex, JoinFilter,
34
    OnceAsync, OnceFut,
35
};
36
use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
37
use crate::{
38
    execution_mode_from_children, DisplayAs, DisplayFormatType, Distribution,
39
    ExecutionMode, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
40
    RecordBatchStream, SendableRecordBatchStream,
41
};
42
43
use arrow::array::{BooleanBufferBuilder, UInt32Array, UInt64Array};
44
use arrow::compute::concat_batches;
45
use arrow::datatypes::{Schema, SchemaRef};
46
use arrow::record_batch::RecordBatch;
47
use arrow::util::bit_util;
48
use datafusion_common::{exec_datafusion_err, JoinSide, Result, Statistics};
49
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
50
use datafusion_execution::TaskContext;
51
use datafusion_expr::JoinType;
52
use datafusion_physical_expr::equivalence::join_equivalence_properties;
53
54
use futures::{ready, Stream, StreamExt, TryStreamExt};
55
use parking_lot::Mutex;
56
57
/// Shared bitmap for visited left-side indices
58
type SharedBitmapBuilder = Mutex<BooleanBufferBuilder>;
59
/// Left (build-side) data
60
struct JoinLeftData {
61
    /// Build-side data collected to single batch
62
    batch: RecordBatch,
63
    /// Shared bitmap builder for visited left indices
64
    bitmap: SharedBitmapBuilder,
65
    /// Counter of running probe-threads, potentially able to update `bitmap`
66
    probe_threads_counter: AtomicUsize,
67
    /// Memory reservation for tracking batch and bitmap
68
    /// Cleared on `JoinLeftData` drop
69
    #[allow(dead_code)]
70
    reservation: MemoryReservation,
71
}
72
73
impl JoinLeftData {
74
44
    fn new(
75
44
        batch: RecordBatch,
76
44
        bitmap: SharedBitmapBuilder,
77
44
        probe_threads_counter: AtomicUsize,
78
44
        reservation: MemoryReservation,
79
44
    ) -> Self {
80
44
        Self {
81
44
            batch,
82
44
            bitmap,
83
44
            probe_threads_counter,
84
44
            reservation,
85
44
        }
86
44
    }
87
88
12.1k
    fn batch(&self) -> &RecordBatch {
89
12.1k
        &self.batch
90
12.1k
    }
91
92
12.2k
    fn bitmap(&self) -> &SharedBitmapBuilder {
93
12.2k
        &self.bitmap
94
12.2k
    }
95
96
    /// Decrements counter of running threads, and returns `true`
97
    /// if caller is the last running thread
98
16
    fn report_probe_completed(&self) -> bool {
99
16
        self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1
100
16
    }
101
}
102
103
/// NestedLoopJoinExec is build-probe join operator, whose main task is to
104
/// perform joins without any equijoin conditions in `ON` clause.
105
///
106
/// Execution consists of following phases:
107
///
108
/// #### 1. Build phase
109
/// Collecting build-side data in memory, by polling all available data from build-side input.
110
/// Due to the absence of equijoin conditions, it's not possible to partition build-side data
111
/// across multiple threads of the operator, so build-side is always collected in a single
112
/// batch shared across all threads.
113
/// The operator always considers LEFT input as build-side input, so it's crucial to adjust
114
/// smaller input to be the LEFT one. Normally this selection is handled by physical optimizer.
115
///
116
/// #### 2. Probe phase
117
/// Sequentially polling batches from the probe-side input and processing them according to the
118
/// following logic:
119
/// - apply join filter (`ON` clause) to Cartesian product of probe batch and build side data
120
///   -- filter evaluation is executed once per build-side data row
121
/// - update shared bitmap of joined ("visited") build-side row indices, if required -- allows
122
///   to produce unmatched build-side data in case of e.g. LEFT/FULL JOIN after probing phase
123
///   completed
124
/// - perform join index alignment is required -- depending on `JoinType`
125
/// - produce output join batch
126
///
127
/// Probing phase is executed in parallel, according to probe-side input partitioning -- one
128
/// thread per partition. After probe input is exhausted, each thread **ATTEMPTS** to produce
129
/// unmatched build-side data.
130
///
131
/// #### 3. Producing unmatched build-side data
132
/// Producing unmatched build-side data as an output batch, after probe input is exhausted.
133
/// This step is also executed in parallel (once per probe input partition), and to avoid
134
/// duplicate output of unmatched data (due to shared nature build-side data), each thread
135
/// "reports" about probe phase completion (which means that "visited" bitmap won't be
136
/// updated anymore), and only the last thread, reporting about completion, will return output.
137
///
138
#[derive(Debug)]
139
pub struct NestedLoopJoinExec {
140
    /// left side
141
    pub(crate) left: Arc<dyn ExecutionPlan>,
142
    /// right side
143
    pub(crate) right: Arc<dyn ExecutionPlan>,
144
    /// Filters which are applied while finding matching rows
145
    pub(crate) filter: Option<JoinFilter>,
146
    /// How the join is performed
147
    pub(crate) join_type: JoinType,
148
    /// The schema once the join is applied
149
    schema: SchemaRef,
150
    /// Build-side data
151
    inner_table: OnceAsync<JoinLeftData>,
152
    /// Information of index and left / right placement of columns
153
    column_indices: Vec<ColumnIndex>,
154
    /// Execution metrics
155
    metrics: ExecutionPlanMetricsSet,
156
    /// Cache holding plan properties like equivalences, output partitioning etc.
157
    cache: PlanProperties,
158
}
159
160
impl NestedLoopJoinExec {
161
    /// Try to create a new [`NestedLoopJoinExec`]
162
52
    pub fn try_new(
163
52
        left: Arc<dyn ExecutionPlan>,
164
52
        right: Arc<dyn ExecutionPlan>,
165
52
        filter: Option<JoinFilter>,
166
52
        join_type: &JoinType,
167
52
    ) -> Result<Self> {
168
52
        let left_schema = left.schema();
169
52
        let right_schema = right.schema();
170
52
        check_join_is_valid(&left_schema, &right_schema, &[])
?0
;
171
52
        let (schema, column_indices) =
172
52
            build_join_schema(&left_schema, &right_schema, join_type);
173
52
        let schema = Arc::new(schema);
174
52
        let cache =
175
52
            Self::compute_properties(&left, &right, Arc::clone(&schema), *join_type);
176
52
177
52
        Ok(NestedLoopJoinExec {
178
52
            left,
179
52
            right,
180
52
            filter,
181
52
            join_type: *join_type,
182
52
            schema,
183
52
            inner_table: Default::default(),
184
52
            column_indices,
185
52
            metrics: Default::default(),
186
52
            cache,
187
52
        })
188
52
    }
189
190
    /// left side
191
0
    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
192
0
        &self.left
193
0
    }
194
195
    /// right side
196
52
    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
197
52
        &self.right
198
52
    }
199
200
    /// Filters applied before join output
201
0
    pub fn filter(&self) -> Option<&JoinFilter> {
202
0
        self.filter.as_ref()
203
0
    }
204
205
    /// How the join is performed
206
0
    pub fn join_type(&self) -> &JoinType {
207
0
        &self.join_type
208
0
    }
209
210
    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
211
52
    fn compute_properties(
212
52
        left: &Arc<dyn ExecutionPlan>,
213
52
        right: &Arc<dyn ExecutionPlan>,
214
52
        schema: SchemaRef,
215
52
        join_type: JoinType,
216
52
    ) -> PlanProperties {
217
52
        // Calculate equivalence properties:
218
52
        let eq_properties = join_equivalence_properties(
219
52
            left.equivalence_properties().clone(),
220
52
            right.equivalence_properties().clone(),
221
52
            &join_type,
222
52
            schema,
223
52
            &Self::maintains_input_order(join_type),
224
52
            None,
225
52
            // No on columns in nested loop join
226
52
            &[],
227
52
        );
228
52
229
52
        let output_partitioning =
230
52
            asymmetric_join_output_partitioning(left, right, &join_type);
231
52
232
52
        // Determine execution mode:
233
52
        let mut mode = execution_mode_from_children([left, right]);
234
52
        if mode.is_unbounded() {
235
0
            mode = ExecutionMode::PipelineBreaking;
236
52
        }
237
238
52
        PlanProperties::new(eq_properties, output_partitioning, mode)
239
52
    }
240
241
    /// Returns a vector indicating whether the left and right inputs maintain their order.
242
    /// The first element corresponds to the left input, and the second to the right.
243
    ///
244
    /// The left (build-side) input's order may change, but the right (probe-side) input's
245
    /// order is maintained for INNER, RIGHT, RIGHT ANTI, and RIGHT SEMI joins.
246
    ///
247
    /// Maintaining the right input's order helps optimize the nodes down the pipeline
248
    /// (See [`ExecutionPlan::maintains_input_order`]).
249
    ///
250
    /// This is a separate method because it is also called when computing properties, before
251
    /// a [`NestedLoopJoinExec`] is created. It also takes [`JoinType`] as an argument, as
252
    /// opposed to `Self`, for the same reason.
253
164
    fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
254
164
        vec![
255
            false,
256
28
            matches!(
257
164
                join_type,
258
                JoinType::Inner
259
                    | JoinType::Right
260
                    | JoinType::RightAnti
261
                    | JoinType::RightSemi
262
            ),
263
        ]
264
164
    }
265
}
266
267
impl DisplayAs for NestedLoopJoinExec {
268
0
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
269
0
        match t {
270
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
271
0
                let display_filter = self.filter.as_ref().map_or_else(
272
0
                    || "".to_string(),
273
0
                    |f| format!(", filter={}", f.expression()),
274
0
                );
275
0
                write!(
276
0
                    f,
277
0
                    "NestedLoopJoinExec: join_type={:?}{}",
278
0
                    self.join_type, display_filter
279
0
                )
280
0
            }
281
0
        }
282
0
    }
283
}
284
285
impl ExecutionPlan for NestedLoopJoinExec {
286
0
    fn name(&self) -> &'static str {
287
0
        "NestedLoopJoinExec"
288
0
    }
289
290
0
    fn as_any(&self) -> &dyn Any {
291
0
        self
292
0
    }
293
294
52
    fn properties(&self) -> &PlanProperties {
295
52
        &self.cache
296
52
    }
297
298
0
    fn required_input_distribution(&self) -> Vec<Distribution> {
299
0
        vec![
300
0
            Distribution::SinglePartition,
301
0
            Distribution::UnspecifiedDistribution,
302
0
        ]
303
0
    }
304
305
112
    fn maintains_input_order(&self) -> Vec<bool> {
306
112
        Self::maintains_input_order(self.join_type)
307
112
    }
308
309
0
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
310
0
        vec![&self.left, &self.right]
311
0
    }
312
313
0
    fn with_new_children(
314
0
        self: Arc<Self>,
315
0
        children: Vec<Arc<dyn ExecutionPlan>>,
316
0
    ) -> Result<Arc<dyn ExecutionPlan>> {
317
0
        Ok(Arc::new(NestedLoopJoinExec::try_new(
318
0
            Arc::clone(&children[0]),
319
0
            Arc::clone(&children[1]),
320
0
            self.filter.clone(),
321
0
            &self.join_type,
322
0
        )?))
323
0
    }
324
325
76
    fn execute(
326
76
        &self,
327
76
        partition: usize,
328
76
        context: Arc<TaskContext>,
329
76
    ) -> Result<SendableRecordBatchStream> {
330
76
        let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
331
76
332
76
        // Initialization reservation for load of inner table
333
76
        let load_reservation =
334
76
            MemoryConsumer::new(format!("NestedLoopJoinLoad[{partition}]"))
335
76
                .register(context.memory_pool());
336
76
337
76
        let inner_table = self.inner_table.once(|| {
338
52
            collect_left_input(
339
52
                Arc::clone(&self.left),
340
52
                Arc::clone(&context),
341
52
                join_metrics.clone(),
342
52
                load_reservation,
343
52
                need_produce_result_in_final(self.join_type),
344
52
                self.right().output_partitioning().partition_count(),
345
52
            )
346
76
        });
347
348
76
        let outer_table = self.right.execute(partition, context)
?0
;
349
350
76
        let indices_cache = (UInt64Array::new_null(0), UInt32Array::new_null(0));
351
352
        // Right side has an order and it is maintained during operation.
353
76
        let right_side_ordered =
354
76
            self.maintains_input_order()[1] && 
self.right.output_ordering().is_some()56
;
355
76
        Ok(Box::pin(NestedLoopJoinStream {
356
76
            schema: Arc::clone(&self.schema),
357
76
            filter: self.filter.clone(),
358
76
            join_type: self.join_type,
359
76
            outer_table,
360
76
            inner_table,
361
76
            is_exhausted: false,
362
76
            column_indices: self.column_indices.clone(),
363
76
            join_metrics,
364
76
            indices_cache,
365
76
            right_side_ordered,
366
76
        }))
367
76
    }
368
369
0
    fn metrics(&self) -> Option<MetricsSet> {
370
0
        Some(self.metrics.clone_inner())
371
0
    }
372
373
0
    fn statistics(&self) -> Result<Statistics> {
374
0
        estimate_join_statistics(
375
0
            Arc::clone(&self.left),
376
0
            Arc::clone(&self.right),
377
0
            vec![],
378
0
            &self.join_type,
379
0
            &self.schema,
380
0
        )
381
0
    }
382
}
383
384
/// Asynchronously collect input into a single batch, and creates `JoinLeftData` from it
385
52
async fn collect_left_input(
386
52
    input: Arc<dyn ExecutionPlan>,
387
52
    context: Arc<TaskContext>,
388
52
    join_metrics: BuildProbeJoinMetrics,
389
52
    reservation: MemoryReservation,
390
52
    with_visited_left_side: bool,
391
52
    probe_threads_count: usize,
392
52
) -> Result<JoinLeftData> {
393
52
    let schema = input.schema();
394
52
    let merge = if input.output_partitioning().partition_count() != 1 {
395
0
        Arc::new(CoalescePartitionsExec::new(input))
396
    } else {
397
52
        input
398
    };
399
52
    let stream = merge.execute(0, context)
?0
;
400
401
    // Load all batches and count the rows
402
52
    let (
batches, metrics, mut reservation44
) = stream
403
52
        .try_fold(
404
52
            (Vec::new(), join_metrics, reservation),
405
12.1k
            |mut acc, batch| async {
406
12.1k
                let batch_size = batch.get_array_memory_size();
407
12.1k
                // Reserve memory for incoming batch
408
12.1k
                acc.2.try_grow(batch_size)
?8
;
409
                // Update metrics
410
12.1k
                acc.1.build_mem_used.add(batch_size);
411
12.1k
                acc.1.build_input_batches.add(1);
412
12.1k
                acc.1.build_input_rows.add(batch.num_rows());
413
12.1k
                // Push batch to output
414
12.1k
                acc.0.push(batch);
415
12.1k
                Ok(acc)
416
24.2k
            },
417
52
        )
418
8
        .
await0
?;
419
420
44
    let merged_batch = concat_batches(&schema, &batches)
?0
;
421
422
    // Reserve memory for visited_left_side bitmap if required by join type
423
44
    let visited_left_side = if with_visited_left_side {
424
        // TODO: Replace `ceil` wrapper with stable `div_cell` after
425
        // https://github.com/rust-lang/rust/issues/88581
426
4
        let buffer_size = bit_util::ceil(merged_batch.num_rows(), 8);
427
4
        reservation.try_grow(buffer_size)
?0
;
428
4
        metrics.build_mem_used.add(buffer_size);
429
4
430
4
        let mut buffer = BooleanBufferBuilder::new(merged_batch.num_rows());
431
4
        buffer.append_n(merged_batch.num_rows(), false);
432
4
        buffer
433
    } else {
434
40
        BooleanBufferBuilder::new(0)
435
    };
436
437
44
    Ok(JoinLeftData::new(
438
44
        merged_batch,
439
44
        Mutex::new(visited_left_side),
440
44
        AtomicUsize::new(probe_threads_count),
441
44
        reservation,
442
44
    ))
443
52
}
444
445
/// A stream that issues [RecordBatch]es as they arrive from the right  of the join.
446
struct NestedLoopJoinStream {
447
    /// Input schema
448
    schema: Arc<Schema>,
449
    /// join filter
450
    filter: Option<JoinFilter>,
451
    /// type of the join
452
    join_type: JoinType,
453
    /// the outer table data of the nested loop join
454
    outer_table: SendableRecordBatchStream,
455
    /// the inner table data of the nested loop join
456
    inner_table: OnceFut<JoinLeftData>,
457
    /// There is nothing to process anymore and left side is processed in case of full join
458
    is_exhausted: bool,
459
    /// Information of index and left / right placement of columns
460
    column_indices: Vec<ColumnIndex>,
461
    // TODO: support null aware equal
462
    // null_equals_null: bool
463
    /// Join execution metrics
464
    join_metrics: BuildProbeJoinMetrics,
465
    /// Cache for join indices calculations
466
    indices_cache: (UInt64Array, UInt32Array),
467
    /// Whether the right side is ordered
468
    right_side_ordered: bool,
469
}
470
471
/// Creates a Cartesian product of two input batches, preserving the order of the right batch,
472
/// and applying a join filter if provided.
473
///
474
/// # Example
475
/// Input:
476
/// left = [0, 1], right = [0, 1, 2]
477
///
478
/// Output:
479
/// left_indices = [0, 1, 0, 1, 0, 1], right_indices = [0, 0, 1, 1, 2, 2]
480
///
481
/// Input:
482
/// left = [0, 1, 2], right = [0, 1, 2, 3], filter = left.a != right.a
483
///
484
/// Output:
485
/// left_indices = [1, 2, 0, 2, 0, 1, 0, 1, 2], right_indices = [0, 0, 1, 1, 2, 2, 3, 3, 3]
486
12.1k
fn build_join_indices(
487
12.1k
    left_batch: &RecordBatch,
488
12.1k
    right_batch: &RecordBatch,
489
12.1k
    filter: Option<&JoinFilter>,
490
12.1k
    indices_cache: &mut (UInt64Array, UInt32Array),
491
12.1k
) -> Result<(UInt64Array, UInt32Array)> {
492
12.1k
    let left_row_count = left_batch.num_rows();
493
12.1k
    let right_row_count = right_batch.num_rows();
494
12.1k
    let output_row_count = left_row_count * right_row_count;
495
12.1k
496
12.1k
    // We always use the same indices before applying the filter, so we can cache them
497
12.1k
    let (left_indices_cache, right_indices_cache) = indices_cache;
498
12.1k
    let cached_output_row_count = left_indices_cache.len();
499
500
12.1k
    let (left_indices, right_indices) =
501
12.1k
        match output_row_count.cmp(&cached_output_row_count) {
502
            std::cmp::Ordering::Equal => {
503
                // Reuse the cached indices
504
12.0k
                (left_indices_cache.clone(), right_indices_cache.clone())
505
            }
506
            std::cmp::Ordering::Less => {
507
                // Left_row_count never changes because it's the build side. The changes to the
508
                // right_row_count can be handled trivially by taking the first output_row_count
509
                // elements of the cache because of how the indices are generated.
510
                // (See the Ordering::Greater match arm)
511
0
                (
512
0
                    left_indices_cache.slice(0, output_row_count),
513
0
                    right_indices_cache.slice(0, output_row_count),
514
0
                )
515
            }
516
            std::cmp::Ordering::Greater => {
517
                // Rebuild the indices cache
518
519
                // Produces 0, 1, 2, 0, 1, 2, 0, 1, 2, ...
520
44
                *left_indices_cache = UInt64Array::from_iter_values(
521
13.2M
                    (0..output_row_count as u64).map(|i| i % left_row_count as u64),
522
44
                );
523
44
524
44
                // Produces 0, 0, 0, 1, 1, 1, 2, 2, 2, ...
525
44
                *right_indices_cache = UInt32Array::from_iter_values(
526
13.2M
                    (0..output_row_count as u32).map(|i| i / left_row_count as u32),
527
44
                );
528
44
529
44
                (left_indices_cache.clone(), right_indices_cache.clone())
530
            }
531
        };
532
533
12.1k
    if let Some(filter) = filter {
534
12.1k
        apply_join_filter_to_indices(
535
12.1k
            left_batch,
536
12.1k
            right_batch,
537
12.1k
            left_indices,
538
12.1k
            right_indices,
539
12.1k
            filter,
540
12.1k
            JoinSide::Left,
541
12.1k
        )
542
    } else {
543
0
        Ok((left_indices, right_indices))
544
    }
545
12.1k
}
546
547
impl NestedLoopJoinStream {
548
12.2k
    fn poll_next_impl(
549
12.2k
        &mut self,
550
12.2k
        cx: &mut std::task::Context<'_>,
551
12.2k
    ) -> Poll<Option<Result<RecordBatch>>> {
552
12.2k
        // all left row
553
12.2k
        let build_timer = self.join_metrics.build_time.timer();
554
12.2k
        let 
left_data12.2k
= match
ready!0
(self.inner_table.get_shared(cx)) {
555
12.2k
            Ok(data) => data,
556
8
            Err(e) => return Poll::Ready(Some(Err(e))),
557
        };
558
12.2k
        build_timer.done();
559
12.2k
560
12.2k
        // Get or initialize visited_left_side bitmap if required by join type
561
12.2k
        let visited_left_side = left_data.bitmap();
562
12.2k
563
12.2k
        // Check is_exhausted before polling the outer_table, such that when the outer table
564
12.2k
        // does not support `FusedStream`, Self will not poll it again
565
12.2k
        if self.is_exhausted {
566
4
            return Poll::Ready(None);
567
12.2k
        }
568
12.2k
569
12.2k
        self.outer_table
570
12.2k
            .poll_next_unpin(cx)
571
12.2k
            .map(|maybe_batch| 
m12.2k
atch
maybe_batch12.1k
{
572
12.1k
                Some(Ok(right_batch)) => {
573
12.1k
                    // Setting up timer & updating input metrics
574
12.1k
                    self.join_metrics.input_batches.add(1);
575
12.1k
                    self.join_metrics.input_rows.add(right_batch.num_rows());
576
12.1k
                    let timer = self.join_metrics.join_time.timer();
577
12.1k
578
12.1k
                    let result = join_left_and_right_batch(
579
12.1k
                        left_data.batch(),
580
12.1k
                        &right_batch,
581
12.1k
                        self.join_type,
582
12.1k
                        self.filter.as_ref(),
583
12.1k
                        &self.column_indices,
584
12.1k
                        &self.schema,
585
12.1k
                        visited_left_side,
586
12.1k
                        &mut self.indices_cache,
587
12.1k
                        self.right_side_ordered,
588
12.1k
                    );
589
590
                    // Recording time & updating output metrics
591
12.1k
                    if let Ok(batch) = &result {
592
12.1k
                        timer.done();
593
12.1k
                        self.join_metrics.output_batches.add(1);
594
12.1k
                        self.join_metrics.output_rows.add(batch.num_rows());
595
12.1k
                    }
0
596
597
12.1k
                    Some(result)
598
                }
599
0
                Some(err) => Some(err),
600
                None => {
601
68
                    if need_produce_result_in_final(self.join_type) {
602
                        // At this stage `visited_left_side` won't be updated, so it's
603
                        // safe to report about probe completion.
604
                        //
605
                        // Setting `is_exhausted` / returning None will prevent from
606
                        // multiple calls of `report_probe_completed()`
607
16
                        if !left_data.report_probe_completed() {
608
12
                            self.is_exhausted = true;
609
12
                            return None;
610
4
                        };
611
4
612
4
                        // Only setting up timer, input is exhausted
613
4
                        let timer = self.join_metrics.join_time.timer();
614
4
                        // use the global left bitmap to produce the left indices and right indices
615
4
                        let (left_side, right_side) =
616
4
                            get_final_indices_from_shared_bitmap(
617
4
                                visited_left_side,
618
4
                                self.join_type,
619
4
                            );
620
4
                        let empty_right_batch =
621
4
                            RecordBatch::new_empty(self.outer_table.schema());
622
4
                        // use the left and right indices to produce the batch result
623
4
                        let result = build_batch_from_indices(
624
4
                            &self.schema,
625
4
                            left_data.batch(),
626
4
                            &empty_right_batch,
627
4
                            &left_side,
628
4
                            &right_side,
629
4
                            &self.column_indices,
630
4
                            JoinSide::Left,
631
4
                        );
632
4
                        self.is_exhausted = true;
633
634
                        // Recording time & updating output metrics
635
4
                        if let Ok(batch) = &result {
636
4
                            timer.done();
637
4
                            self.join_metrics.output_batches.add(1);
638
4
                            self.join_metrics.output_rows.add(batch.num_rows());
639
4
                        }
0
640
641
4
                        Some(result)
642
                    } else {
643
                        // end of the join loop
644
52
                        None
645
                    }
646
                }
647
12.2k
            
}12.2k
)
648
12.2k
    }
649
}
650
651
#[allow(clippy::too_many_arguments)]
652
12.1k
fn join_left_and_right_batch(
653
12.1k
    left_batch: &RecordBatch,
654
12.1k
    right_batch: &RecordBatch,
655
12.1k
    join_type: JoinType,
656
12.1k
    filter: Option<&JoinFilter>,
657
12.1k
    column_indices: &[ColumnIndex],
658
12.1k
    schema: &Schema,
659
12.1k
    visited_left_side: &SharedBitmapBuilder,
660
12.1k
    indices_cache: &mut (UInt64Array, UInt32Array),
661
12.1k
    right_side_ordered: bool,
662
12.1k
) -> Result<RecordBatch> {
663
12.1k
    let (left_side, right_side) =
664
12.1k
        build_join_indices(left_batch, right_batch, filter, indices_cache).map_err(
665
12.1k
            |e| {
666
0
                exec_datafusion_err!(
667
0
                    "Fail to build join indices in NestedLoopJoinExec, error: {e}"
668
0
                )
669
12.1k
            },
670
12.1k
        )
?0
;
671
672
    // set the left bitmap
673
    // and only full join need the left bitmap
674
12.1k
    if need_produce_result_in_final(join_type) {
675
4
        let mut bitmap = visited_left_side.lock();
676
4
        left_side.values().iter().for_each(|x| {
677
4
            bitmap.set_bit(*x as usize, true);
678
4
        });
679
12.1k
    }
680
    // adjust the two side indices base on the join type
681
12.1k
    let (left_side, right_side) = adjust_indices_by_join_type(
682
12.1k
        left_side,
683
12.1k
        right_side,
684
12.1k
        0..right_batch.num_rows(),
685
12.1k
        join_type,
686
12.1k
        right_side_ordered,
687
12.1k
    );
688
12.1k
689
12.1k
    build_batch_from_indices(
690
12.1k
        schema,
691
12.1k
        left_batch,
692
12.1k
        right_batch,
693
12.1k
        &left_side,
694
12.1k
        &right_side,
695
12.1k
        column_indices,
696
12.1k
        JoinSide::Left,
697
12.1k
    )
698
12.1k
}
699
700
4
fn get_final_indices_from_shared_bitmap(
701
4
    shared_bitmap: &SharedBitmapBuilder,
702
4
    join_type: JoinType,
703
4
) -> (UInt64Array, UInt32Array) {
704
4
    let bitmap = shared_bitmap.lock();
705
4
    get_final_indices_from_bit_map(&bitmap, join_type)
706
4
}
707
708
impl Stream for NestedLoopJoinStream {
709
    type Item = Result<RecordBatch>;
710
711
12.2k
    fn poll_next(
712
12.2k
        mut self: std::pin::Pin<&mut Self>,
713
12.2k
        cx: &mut std::task::Context<'_>,
714
12.2k
    ) -> Poll<Option<Self::Item>> {
715
12.2k
        self.poll_next_impl(cx)
716
12.2k
    }
717
}
718
719
impl RecordBatchStream for NestedLoopJoinStream {
720
0
    fn schema(&self) -> SchemaRef {
721
0
        Arc::clone(&self.schema)
722
0
    }
723
}
724
725
#[cfg(test)]
726
mod tests {
727
    use super::*;
728
    use crate::{
729
        common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec,
730
        test::build_table_i32,
731
    };
732
733
    use arrow::datatypes::{DataType, Field};
734
    use arrow_array::Int32Array;
735
    use arrow_schema::SortOptions;
736
    use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue};
737
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
738
    use datafusion_expr::Operator;
739
    use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
740
    use datafusion_physical_expr::{Partitioning, PhysicalExpr};
741
    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
742
743
    use rstest::rstest;
744
745
90
    fn build_table(
746
90
        a: (&str, &Vec<i32>),
747
90
        b: (&str, &Vec<i32>),
748
90
        c: (&str, &Vec<i32>),
749
90
        batch_size: Option<usize>,
750
90
        sorted_column_names: Vec<&str>,
751
90
    ) -> Arc<dyn ExecutionPlan> {
752
90
        let batch = build_table_i32(a, b, c);
753
90
        let schema = batch.schema();
754
755
90
        let batches = if let Some(
batch_size72
) = batch_size {
756
72
            let num_batches = batch.num_rows().div_ceil(batch_size);
757
72
            (0..num_batches)
758
24.2k
                .map(|i| {
759
24.2k
                    let start = i * batch_size;
760
24.2k
                    let remaining_rows = batch.num_rows() - start;
761
24.2k
                    batch.slice(start, batch_size.min(remaining_rows))
762
24.2k
                })
763
72
                .collect::<Vec<_>>()
764
        } else {
765
18
            vec![batch]
766
        };
767
768
90
        let mut exec =
769
90
            MemoryExec::try_new(&[batches], Arc::clone(&schema), None).unwrap();
770
90
        if !sorted_column_names.is_empty() {
771
36
            let mut sort_info = Vec::new();
772
144
            for 
name108
in sorted_column_names {
773
108
                let index = schema.index_of(name).unwrap();
774
108
                let sort_expr = PhysicalSortExpr {
775
108
                    expr: Arc::new(Column::new(name, index)),
776
108
                    options: SortOptions {
777
108
                        descending: false,
778
108
                        nulls_first: false,
779
108
                    },
780
108
                };
781
108
                sort_info.push(sort_expr);
782
108
            }
783
36
            exec = exec.with_sort_information(vec![sort_info]);
784
54
        }
785
786
90
        Arc::new(exec)
787
90
    }
788
789
8
    fn build_left_table() -> Arc<dyn ExecutionPlan> {
790
8
        build_table(
791
8
            ("a1", &vec![5, 9, 11]),
792
8
            ("b1", &vec![5, 8, 8]),
793
8
            ("c1", &vec![50, 90, 110]),
794
8
            None,
795
8
            Vec::new(),
796
8
        )
797
8
    }
798
799
8
    fn build_right_table() -> Arc<dyn ExecutionPlan> {
800
8
        build_table(
801
8
            ("a2", &vec![12, 2, 10]),
802
8
            ("b2", &vec![10, 2, 10]),
803
8
            ("c2", &vec![40, 80, 100]),
804
8
            None,
805
8
            Vec::new(),
806
8
        )
807
8
    }
808
809
9
    fn prepare_join_filter() -> JoinFilter {
810
9
        let column_indices = vec![
811
9
            ColumnIndex {
812
9
                index: 1,
813
9
                side: JoinSide::Left,
814
9
            },
815
9
            ColumnIndex {
816
9
                index: 1,
817
9
                side: JoinSide::Right,
818
9
            },
819
9
        ];
820
9
        let intermediate_schema = Schema::new(vec![
821
9
            Field::new("x", DataType::Int32, true),
822
9
            Field::new("x", DataType::Int32, true),
823
9
        ]);
824
9
        // left.b1!=8
825
9
        let left_filter = Arc::new(BinaryExpr::new(
826
9
            Arc::new(Column::new("x", 0)),
827
9
            Operator::NotEq,
828
9
            Arc::new(Literal::new(ScalarValue::Int32(Some(8)))),
829
9
        )) as Arc<dyn PhysicalExpr>;
830
9
        // right.b2!=10
831
9
        let right_filter = Arc::new(BinaryExpr::new(
832
9
            Arc::new(Column::new("x", 1)),
833
9
            Operator::NotEq,
834
9
            Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
835
9
        )) as Arc<dyn PhysicalExpr>;
836
9
        // filter = left.b1!=8 and right.b2!=10
837
9
        // after filter:
838
9
        // left table:
839
9
        // ("a1", &vec![5]),
840
9
        // ("b1", &vec![5]),
841
9
        // ("c1", &vec![50]),
842
9
        // right table:
843
9
        // ("a2", &vec![12, 2]),
844
9
        // ("b2", &vec![10, 2]),
845
9
        // ("c2", &vec![40, 80]),
846
9
        let filter_expression =
847
9
            Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter))
848
9
                as Arc<dyn PhysicalExpr>;
849
9
850
9
        JoinFilter::new(filter_expression, column_indices, intermediate_schema)
851
9
    }
852
853
16
    async fn multi_partitioned_join_collect(
854
16
        left: Arc<dyn ExecutionPlan>,
855
16
        right: Arc<dyn ExecutionPlan>,
856
16
        join_type: &JoinType,
857
16
        join_filter: Option<JoinFilter>,
858
16
        context: Arc<TaskContext>,
859
16
    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
860
16
        let partition_count = 4;
861
862
        // Redistributing right input
863
16
        let right = Arc::new(RepartitionExec::try_new(
864
16
            right,
865
16
            Partitioning::RoundRobinBatch(partition_count),
866
16
        )
?0
) as Arc<dyn ExecutionPlan>;
867
868
        // Use the required distribution for nested loop join to test partition data
869
16
        let nested_loop_join =
870
16
            NestedLoopJoinExec::try_new(left, right, join_filter, join_type)
?0
;
871
16
        let columns = columns(&nested_loop_join.schema());
872
16
        let mut batches = vec![];
873
40
        for i in 0..
partition_count16
{
874
40
            let stream = nested_loop_join.execute(i, Arc::clone(&context))
?0
;
875
40
            let 
more_batches32
= common::collect(stream).
await8
?8
;
876
32
            batches.extend(
877
32
                more_batches
878
32
                    .into_iter()
879
32
                    .filter(|b| 
b.num_rows() > 012
)
880
32
                    .collect::<Vec<_>>(),
881
32
            );
882
32
        }
883
8
        Ok((columns, batches))
884
16
    }
885
886
    #[tokio::test]
887
1
    async fn join_inner_with_filter() -> Result<()> {
888
1
        let task_ctx = Arc::new(TaskContext::default());
889
1
        let left = build_left_table();
890
1
        let right = build_right_table();
891
1
        let filter = prepare_join_filter();
892
1
        let (columns, batches) = multi_partitioned_join_collect(
893
1
            left,
894
1
            right,
895
1
            &JoinType::Inner,
896
1
            Some(filter),
897
1
            task_ctx,
898
1
        )
899
1
        .await
?0
;
900
1
        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
901
1
        let expected = [
902
1
            "+----+----+----+----+----+----+",
903
1
            "| a1 | b1 | c1 | a2 | b2 | c2 |",
904
1
            "+----+----+----+----+----+----+",
905
1
            "| 5  | 5  | 50 | 2  | 2  | 80 |",
906
1
            "+----+----+----+----+----+----+",
907
1
        ];
908
1
909
1
        assert_batches_sorted_eq!(expected, &batches);
910
1
911
1
        Ok(())
912
1
    }
913
914
    #[tokio::test]
915
1
    async fn join_left_with_filter() -> Result<()> {
916
1
        let task_ctx = Arc::new(TaskContext::default());
917
1
        let left = build_left_table();
918
1
        let right = build_right_table();
919
1
920
1
        let filter = prepare_join_filter();
921
1
        let (columns, batches) = multi_partitioned_join_collect(
922
1
            left,
923
1
            right,
924
1
            &JoinType::Left,
925
1
            Some(filter),
926
1
            task_ctx,
927
1
        )
928
1
        .await
?0
;
929
1
        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
930
1
        let expected = [
931
1
            "+----+----+-----+----+----+----+",
932
1
            "| a1 | b1 | c1  | a2 | b2 | c2 |",
933
1
            "+----+----+-----+----+----+----+",
934
1
            "| 11 | 8  | 110 |    |    |    |",
935
1
            "| 5  | 5  | 50  | 2  | 2  | 80 |",
936
1
            "| 9  | 8  | 90  |    |    |    |",
937
1
            "+----+----+-----+----+----+----+",
938
1
        ];
939
1
940
1
        assert_batches_sorted_eq!(expected, &batches);
941
1
942
1
        Ok(())
943
1
    }
944
945
    #[tokio::test]
946
1
    async fn join_right_with_filter() -> Result<()> {
947
1
        let task_ctx = Arc::new(TaskContext::default());
948
1
        let left = build_left_table();
949
1
        let right = build_right_table();
950
1
951
1
        let filter = prepare_join_filter();
952
1
        let (columns, batches) = multi_partitioned_join_collect(
953
1
            left,
954
1
            right,
955
1
            &JoinType::Right,
956
1
            Some(filter),
957
1
            task_ctx,
958
1
        )
959
1
        .await
?0
;
960
1
        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
961
1
        let expected = [
962
1
            "+----+----+----+----+----+-----+",
963
1
            "| a1 | b1 | c1 | a2 | b2 | c2  |",
964
1
            "+----+----+----+----+----+-----+",
965
1
            "|    |    |    | 10 | 10 | 100 |",
966
1
            "|    |    |    | 12 | 10 | 40  |",
967
1
            "| 5  | 5  | 50 | 2  | 2  | 80  |",
968
1
            "+----+----+----+----+----+-----+",
969
1
        ];
970
1
971
1
        assert_batches_sorted_eq!(expected, &batches);
972
1
973
1
        Ok(())
974
1
    }
975
976
    #[tokio::test]
977
1
    async fn join_full_with_filter() -> Result<()> {
978
1
        let task_ctx = Arc::new(TaskContext::default());
979
1
        let left = build_left_table();
980
1
        let right = build_right_table();
981
1
982
1
        let filter = prepare_join_filter();
983
1
        let (columns, batches) = multi_partitioned_join_collect(
984
1
            left,
985
1
            right,
986
1
            &JoinType::Full,
987
1
            Some(filter),
988
1
            task_ctx,
989
1
        )
990
1
        .await
?0
;
991
1
        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
992
1
        let expected = [
993
1
            "+----+----+-----+----+----+-----+",
994
1
            "| a1 | b1 | c1  | a2 | b2 | c2  |",
995
1
            "+----+----+-----+----+----+-----+",
996
1
            "|    |    |     | 10 | 10 | 100 |",
997
1
            "|    |    |     | 12 | 10 | 40  |",
998
1
            "| 11 | 8  | 110 |    |    |     |",
999
1
            "| 5  | 5  | 50  | 2  | 2  | 80  |",
1000
1
            "| 9  | 8  | 90  |    |    |     |",
1001
1
            "+----+----+-----+----+----+-----+",
1002
1
        ];
1003
1
1004
1
        assert_batches_sorted_eq!(expected, &batches);
1005
1
1006
1
        Ok(())
1007
1
    }
1008
1009
    #[tokio::test]
1010
1
    async fn join_left_semi_with_filter() -> Result<()> {
1011
1
        let task_ctx = Arc::new(TaskContext::default());
1012
1
        let left = build_left_table();
1013
1
        let right = build_right_table();
1014
1
1015
1
        let filter = prepare_join_filter();
1016
1
        let (columns, batches) = multi_partitioned_join_collect(
1017
1
            left,
1018
1
            right,
1019
1
            &JoinType::LeftSemi,
1020
1
            Some(filter),
1021
1
            task_ctx,
1022
1
        )
1023
1
        .await
?0
;
1024
1
        assert_eq!(columns, vec!["a1", "b1", "c1"]);
1025
1
        let expected = [
1026
1
            "+----+----+----+",
1027
1
            "| a1 | b1 | c1 |",
1028
1
            "+----+----+----+",
1029
1
            "| 5  | 5  | 50 |",
1030
1
            "+----+----+----+",
1031
1
        ];
1032
1
1033
1
        assert_batches_sorted_eq!(expected, &batches);
1034
1
1035
1
        Ok(())
1036
1
    }
1037
1038
    #[tokio::test]
1039
1
    async fn join_left_anti_with_filter() -> Result<()> {
1040
1
        let task_ctx = Arc::new(TaskContext::default());
1041
1
        let left = build_left_table();
1042
1
        let right = build_right_table();
1043
1
1044
1
        let filter = prepare_join_filter();
1045
1
        let (columns, batches) = multi_partitioned_join_collect(
1046
1
            left,
1047
1
            right,
1048
1
            &JoinType::LeftAnti,
1049
1
            Some(filter),
1050
1
            task_ctx,
1051
1
        )
1052
1
        .await
?0
;
1053
1
        assert_eq!(columns, vec!["a1", "b1", "c1"]);
1054
1
        let expected = [
1055
1
            "+----+----+-----+",
1056
1
            "| a1 | b1 | c1  |",
1057
1
            "+----+----+-----+",
1058
1
            "| 11 | 8  | 110 |",
1059
1
            "| 9  | 8  | 90  |",
1060
1
            "+----+----+-----+",
1061
1
        ];
1062
1
1063
1
        assert_batches_sorted_eq!(expected, &batches);
1064
1
1065
1
        Ok(())
1066
1
    }
1067
1068
    #[tokio::test]
1069
1
    async fn join_right_semi_with_filter() -> Result<()> {
1070
1
        let task_ctx = Arc::new(TaskContext::default());
1071
1
        let left = build_left_table();
1072
1
        let right = build_right_table();
1073
1
1074
1
        let filter = prepare_join_filter();
1075
1
        let (columns, batches) = multi_partitioned_join_collect(
1076
1
            left,
1077
1
            right,
1078
1
            &JoinType::RightSemi,
1079
1
            Some(filter),
1080
1
            task_ctx,
1081
1
        )
1082
1
        .await
?0
;
1083
1
        assert_eq!(columns, vec!["a2", "b2", "c2"]);
1084
1
        let expected = [
1085
1
            "+----+----+----+",
1086
1
            "| a2 | b2 | c2 |",
1087
1
            "+----+----+----+",
1088
1
            "| 2  | 2  | 80 |",
1089
1
            "+----+----+----+",
1090
1
        ];
1091
1
1092
1
        assert_batches_sorted_eq!(expected, &batches);
1093
1
1094
1
        Ok(())
1095
1
    }
1096
1097
    #[tokio::test]
1098
1
    async fn join_right_anti_with_filter() -> Result<()> {
1099
1
        let task_ctx = Arc::new(TaskContext::default());
1100
1
        let left = build_left_table();
1101
1
        let right = build_right_table();
1102
1
1103
1
        let filter = prepare_join_filter();
1104
1
        let (columns, batches) = multi_partitioned_join_collect(
1105
1
            left,
1106
1
            right,
1107
1
            &JoinType::RightAnti,
1108
1
            Some(filter),
1109
1
            task_ctx,
1110
1
        )
1111
1
        .await
?0
;
1112
1
        assert_eq!(columns, vec!["a2", "b2", "c2"]);
1113
1
        let expected = [
1114
1
            "+----+----+-----+",
1115
1
            "| a2 | b2 | c2  |",
1116
1
            "+----+----+-----+",
1117
1
            "| 10 | 10 | 100 |",
1118
1
            "| 12 | 10 | 40  |",
1119
1
            "+----+----+-----+",
1120
1
        ];
1121
1
1122
1
        assert_batches_sorted_eq!(expected, &batches);
1123
1
1124
1
        Ok(())
1125
1
    }
1126
1127
    #[tokio::test]
1128
1
    async fn test_overallocation() -> Result<()> {
1129
1
        let left = build_table(
1130
1
            ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
1131
1
            ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
1132
1
            ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
1133
1
            None,
1134
1
            Vec::new(),
1135
1
        );
1136
1
        let right = build_table(
1137
1
            ("a2", &vec![10, 11]),
1138
1
            ("b2", &vec![12, 13]),
1139
1
            ("c2", &vec![14, 15]),
1140
1
            None,
1141
1
            Vec::new(),
1142
1
        );
1143
1
        let filter = prepare_join_filter();
1144
1
1145
1
        let join_types = vec![
1146
1
            JoinType::Inner,
1147
1
            JoinType::Left,
1148
1
            JoinType::Right,
1149
1
            JoinType::Full,
1150
1
            JoinType::LeftSemi,
1151
1
            JoinType::LeftAnti,
1152
1
            JoinType::RightSemi,
1153
1
            JoinType::RightAnti,
1154
1
        ];
1155
1
1156
9
        for 
join_type8
in join_types {
1157
8
            let runtime = RuntimeEnvBuilder::new()
1158
8
                .with_memory_limit(100, 1.0)
1159
8
                .build_arc()
?0
;
1160
8
            let task_ctx = TaskContext::default().with_runtime(runtime);
1161
8
            let task_ctx = Arc::new(task_ctx);
1162
1
1163
8
            let err = multi_partitioned_join_collect(
1164
8
                Arc::clone(&left),
1165
8
                Arc::clone(&right),
1166
8
                &join_type,
1167
8
                Some(filter.clone()),
1168
8
                task_ctx,
1169
8
            )
1170
1
            .
await0
1171
8
            .unwrap_err();
1172
8
1173
8
            assert_contains!(
1174
8
                err.to_string(),
1175
8
                "External error: Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: NestedLoopJoinLoad[0]"
1176
8
            );
1177
1
        }
1178
1
1179
1
        Ok(())
1180
1
    }
1181
1182
36
    fn prepare_mod_join_filter() -> JoinFilter {
1183
36
        let column_indices = vec![
1184
36
            ColumnIndex {
1185
36
                index: 1,
1186
36
                side: JoinSide::Left,
1187
36
            },
1188
36
            ColumnIndex {
1189
36
                index: 1,
1190
36
                side: JoinSide::Right,
1191
36
            },
1192
36
        ];
1193
36
        let intermediate_schema = Schema::new(vec![
1194
36
            Field::new("x", DataType::Int32, true),
1195
36
            Field::new("x", DataType::Int32, true),
1196
36
        ]);
1197
36
1198
36
        // left.b1 % 3
1199
36
        let left_mod = Arc::new(BinaryExpr::new(
1200
36
            Arc::new(Column::new("x", 0)),
1201
36
            Operator::Modulo,
1202
36
            Arc::new(Literal::new(ScalarValue::Int32(Some(3)))),
1203
36
        )) as Arc<dyn PhysicalExpr>;
1204
36
        // left.b1 % 3 != 0
1205
36
        let left_filter = Arc::new(BinaryExpr::new(
1206
36
            left_mod,
1207
36
            Operator::NotEq,
1208
36
            Arc::new(Literal::new(ScalarValue::Int32(Some(0)))),
1209
36
        )) as Arc<dyn PhysicalExpr>;
1210
36
1211
36
        // right.b2 % 5
1212
36
        let right_mod = Arc::new(BinaryExpr::new(
1213
36
            Arc::new(Column::new("x", 1)),
1214
36
            Operator::Modulo,
1215
36
            Arc::new(Literal::new(ScalarValue::Int32(Some(5)))),
1216
36
        )) as Arc<dyn PhysicalExpr>;
1217
36
        // right.b2 % 5 != 0
1218
36
        let right_filter = Arc::new(BinaryExpr::new(
1219
36
            right_mod,
1220
36
            Operator::NotEq,
1221
36
            Arc::new(Literal::new(ScalarValue::Int32(Some(0)))),
1222
36
        )) as Arc<dyn PhysicalExpr>;
1223
36
        // filter = left.b1 % 3 != 0 and right.b2 % 5 != 0
1224
36
        let filter_expression =
1225
36
            Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter))
1226
36
                as Arc<dyn PhysicalExpr>;
1227
36
1228
36
        JoinFilter::new(filter_expression, column_indices, intermediate_schema)
1229
36
    }
1230
1231
72
    /// Returns `num_columns` identical columns, each holding the values
    /// `1..=num_rows` in ascending order.
    fn generate_columns(num_columns: usize, num_rows: usize) -> Vec<Vec<i32>> {
        let template: Vec<i32> = (1..=num_rows).map(|row| row as i32).collect();
        (0..num_columns).map(|_| template.clone()).collect()
    }
1235
1236
36
    #[rstest]
1237
    #[tokio::test]
1238
    async fn join_maintains_right_order(
1239
        #[values(
1240
            JoinType::Inner,
1241
            JoinType::Right,
1242
            JoinType::RightAnti,
1243
            JoinType::RightSemi
1244
        )]
1245
        join_type: JoinType,
1246
        #[values(1, 100, 1000)] left_batch_size: usize,
1247
        #[values(1, 100, 1000)] right_batch_size: usize,
1248
    ) -> Result<()> {
1249
        let left_columns = generate_columns(3, 1000);
1250
        let left = build_table(
1251
            ("a1", &left_columns[0]),
1252
            ("b1", &left_columns[1]),
1253
            ("c1", &left_columns[2]),
1254
            Some(left_batch_size),
1255
            Vec::new(),
1256
        );
1257
1258
        let right_columns = generate_columns(3, 1000);
1259
        let right = build_table(
1260
            ("a2", &right_columns[0]),
1261
            ("b2", &right_columns[1]),
1262
            ("c2", &right_columns[2]),
1263
            Some(right_batch_size),
1264
            vec!["a2", "b2", "c2"],
1265
        );
1266
1267
        let filter = prepare_mod_join_filter();
1268
1269
        let nested_loop_join = Arc::new(NestedLoopJoinExec::try_new(
1270
            left,
1271
            Arc::clone(&right),
1272
            Some(filter),
1273
            &join_type,
1274
        )?) as Arc<dyn ExecutionPlan>;
1275
        assert_eq!(nested_loop_join.maintains_input_order(), vec![false, true]);
1276
1277
        let right_column_indices = match join_type {
1278
            JoinType::Inner | JoinType::Right => vec![3, 4, 5],
1279
            JoinType::RightAnti | JoinType::RightSemi => vec![0, 1, 2],
1280
            _ => unreachable!(),
1281
        };
1282
1283
        let right_ordering = right.output_ordering().unwrap();
1284
        let join_ordering = nested_loop_join.output_ordering().unwrap();
1285
        for (right, join) in right_ordering.iter().zip(join_ordering.iter()) {
1286
            let right_column = right.expr.as_any().downcast_ref::<Column>().unwrap();
1287
            let join_column = join.expr.as_any().downcast_ref::<Column>().unwrap();
1288
            assert_eq!(join_column.name(), join_column.name());
1289
            assert_eq!(
1290
                right_column_indices[right_column.index()],
1291
                join_column.index()
1292
            );
1293
            assert_eq!(right.options, join.options);
1294
        }
1295
1296
        let batches = nested_loop_join
1297
            .execute(0, Arc::new(TaskContext::default()))?
1298
            .try_collect::<Vec<_>>()
1299
            .await?;
1300
1301
        // Make sure that the order of the right side is maintained
1302
        let mut prev_values = [i32::MIN, i32::MIN, i32::MIN];
1303
1304
        for (batch_index, batch) in batches.iter().enumerate() {
1305
            let columns: Vec<_> = right_column_indices
1306
                .iter()
1307
36.3k
                .map(|&i| {
1308
36.3k
                    batch
1309
36.3k
                        .column(i)
1310
36.3k
                        .as_any()
1311
36.3k
                        .downcast_ref::<Int32Array>()
1312
36.3k
                        .unwrap()
1313
36.3k
                })
1314
                .collect();
1315
1316
            for row in 0..batch.num_rows() {
1317
                let current_values = [
1318
                    columns[0].value(row),
1319
                    columns[1].value(row),
1320
                    columns[2].value(row),
1321
                ];
1322
                assert!(
1323
                    current_values
1324
                        .into_iter()
1325
                        .zip(prev_values)
1326
28.8M
                        .all(|(current, prev)| current >= prev),
1327
                    "batch_index: {} row: {} current: {:?}, prev: {:?}",
1328
                    batch_index,
1329
                    row,
1330
                    current_values,
1331
                    prev_values
1332
                );
1333
                prev_values = current_values;
1334
            }
1335
        }
1336
1337
        Ok(())
1338
    }
1339
1340
    /// Returns the column names on the schema
1341
16
    fn columns(schema: &Schema) -> Vec<String> {
1342
72
        schema.fields().iter().map(|f| f.name().clone()).collect()
1343
16
    }
1344
}