Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! This file implements the symmetric hash join algorithm with range-based
19
//! data pruning to join two (potentially infinite) streams.
20
//!
21
//! A [`SymmetricHashJoinExec`] plan takes two child plans (with appropriate
22
//! output ordering) and produces the join output according to the given join
23
//! type and other options.
24
//!
25
//! This plan uses the [`OneSideHashJoiner`] object to facilitate join calculations
26
//! for both its children.
27
28
use std::any::Any;
29
use std::fmt::{self, Debug};
30
use std::sync::Arc;
31
use std::task::{Context, Poll};
32
use std::vec;
33
34
use crate::common::SharedMemoryReservation;
35
use crate::handle_state;
36
use crate::joins::hash_join::{equal_rows_arr, update_hash};
37
use crate::joins::stream_join_utils::{
38
    calculate_filter_expr_intervals, combine_two_batches,
39
    convert_sort_expr_with_filter_schema, get_pruning_anti_indices,
40
    get_pruning_semi_indices, prepare_sorted_exprs, record_visited_indices,
41
    PruningJoinHashMap, SortedFilterExpr, StreamJoinMetrics,
42
};
43
use crate::joins::utils::{
44
    apply_join_filter_to_indices, build_batch_from_indices, build_join_schema,
45
    check_join_is_valid, symmetric_join_output_partitioning, ColumnIndex, JoinFilter,
46
    JoinHashMapType, JoinOn, JoinOnRef, StatefulStreamResult,
47
};
48
use crate::{
49
    execution_mode_from_children,
50
    expressions::PhysicalSortExpr,
51
    joins::StreamJoinPartitionMode,
52
    metrics::{ExecutionPlanMetricsSet, MetricsSet},
53
    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
54
    PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
55
};
56
57
use arrow::array::{
58
    ArrowPrimitiveType, NativeAdapter, PrimitiveArray, PrimitiveBuilder, UInt32Array,
59
    UInt64Array,
60
};
61
use arrow::compute::concat_batches;
62
use arrow::datatypes::{Schema, SchemaRef};
63
use arrow::record_batch::RecordBatch;
64
use datafusion_common::hash_utils::create_hashes;
65
use datafusion_common::utils::bisect;
66
use datafusion_common::{internal_err, plan_err, JoinSide, JoinType, Result};
67
use datafusion_execution::memory_pool::MemoryConsumer;
68
use datafusion_execution::TaskContext;
69
use datafusion_expr::interval_arithmetic::Interval;
70
use datafusion_physical_expr::equivalence::join_equivalence_properties;
71
use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
72
use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement};
73
74
use ahash::RandomState;
75
use datafusion_physical_expr_common::sort_expr::LexRequirement;
76
use futures::{ready, Stream, StreamExt};
77
use hashbrown::HashSet;
78
use parking_lot::Mutex;
79
80
/// Scale factor for shrinking hash maps.
///
/// NOTE(review): the use sites are outside this view; presumably this controls when
/// (and by how much) a pruned hash map's capacity is reduced — confirm against usages.
const HASHMAP_SHRINK_SCALE_FACTOR: usize = 4;
81
82
/// A symmetric hash join with range conditions is when both streams are hashed on the
83
/// join key and the resulting hash tables are used to join the streams.
84
/// The join is considered symmetric because the hash table is built on the join keys from both
85
/// streams, and the matching of rows is based on the values of the join keys in both streams.
86
/// This type of join is efficient in streaming context as it allows for fast lookups in the hash
87
/// table, rather than having to scan through one or both of the streams to find matching rows, also it
88
/// only considers the elements from the stream that fall within a certain sliding window (w/ range conditions),
89
/// making it more efficient and less likely to store stale data. This enables operating on unbounded streaming
90
/// data without any memory issues.
91
///
92
/// For each input stream, create a hash table.
93
///   - For each new [RecordBatch] in build side, hash and insert into inputs hash table. Update offsets.
94
///   - Test if input is equal to a predefined set of other inputs.
95
///   - If so record the visited rows. If the matched row results must be produced (INNER, LEFT), output the [RecordBatch].
96
///   - Try to prune other side (probe) with new [RecordBatch].
97
///   - If the join type indicates that the unmatched rows results must be produced (LEFT, FULL etc.),
98
///     output the [RecordBatch] when a pruning happens or at the end of the data.
99
///
100
///
101
/// ``` text
102
///                        +-------------------------+
103
///                        |                         |
104
///   left stream ---------|  Left OneSideHashJoiner |---+
105
///                        |                         |   |
106
///                        +-------------------------+   |
107
///                                                      |
108
///                                                      |--------- Joined output
109
///                                                      |
110
///                        +-------------------------+   |
111
///                        |                         |   |
112
///  right stream ---------| Right OneSideHashJoiner |---+
113
///                        |                         |
114
///                        +-------------------------+
115
///
116
/// Prune build side when the new RecordBatch comes to the probe side. We utilize interval arithmetic
117
/// on JoinFilter's sorted PhysicalExprs to calculate the joinable range.
118
///
119
///
120
///               PROBE SIDE          BUILD SIDE
121
///                 BUFFER              BUFFER
122
///             +-------------+     +------------+
123
///             |             |     |            |    Unjoinable
124
///             |             |     |            |    Range
125
///             |             |     |            |
126
///             |             |  |---------------------------------
127
///             |             |  |  |            |
128
///             |             |  |  |            |
129
///             |             | /   |            |
130
///             |             | |   |            |
131
///             |             | |   |            |
132
///             |             | |   |            |
133
///             |             | |   |            |
134
///             |             | |   |            |    Joinable
135
///             |             |/    |            |    Range
136
///             |             ||    |            |
137
///             |+-----------+||    |            |
138
///             || Record    ||     |            |
139
///             || Batch     ||     |            |
140
///             |+-----------+||    |            |
141
///             +-------------+\    +------------+
142
///                             |
143
///                             \
144
///                              |---------------------------------
145
///
146
///  This happens when range conditions are provided on sorted columns. E.g.
147
///
148
///        SELECT * FROM left_table, right_table
149
///        ON
150
///          left_key = right_key AND
151
///          left_time > right_time - INTERVAL 12 MINUTES AND left_time < right_time + INTERVAL 2 HOUR
152
///
153
/// or
154
///       SELECT * FROM left_table, right_table
155
///        ON
156
///          left_key = right_key AND
157
///          left_sorted > right_sorted - 3 AND left_sorted < right_sorted + 10
158
///
159
/// For general purpose, in the second scenario, when the new data comes to probe side, the conditions can be used to
160
/// determine a specific threshold for discarding rows from the inner buffer. For example, if the sort order the
161
/// two columns ("left_sorted" and "right_sorted") are ascending (it can be different in another scenarios)
162
/// and the join condition is "left_sorted > right_sorted - 3" and the latest value on the right input is 1234, meaning
163
/// that the left side buffer must only keep rows where "leftTime > rightTime - 3 > 1234 - 3 > 1231" ,
164
/// making the smallest value in 'left_sorted' 1231 and any rows below (since ascending)
165
/// than that can be dropped from the inner buffer.
166
/// ```
167
#[derive(Debug)]
168
pub struct SymmetricHashJoinExec {
169
    /// Left side stream
170
    pub(crate) left: Arc<dyn ExecutionPlan>,
171
    /// Right side stream
172
    pub(crate) right: Arc<dyn ExecutionPlan>,
173
    /// Set of common columns used to join on
174
    pub(crate) on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
175
    /// Filters applied when finding matching rows
176
    pub(crate) filter: Option<JoinFilter>,
177
    /// How the join is performed
178
    pub(crate) join_type: JoinType,
179
    /// Shares the `RandomState` for the hashing algorithm
180
    random_state: RandomState,
181
    /// Execution metrics
182
    metrics: ExecutionPlanMetricsSet,
183
    /// Information of index and left / right placement of columns
184
    column_indices: Vec<ColumnIndex>,
185
    /// If null_equals_null is true, null == null else null != null
186
    pub(crate) null_equals_null: bool,
187
    /// Left side sort expression(s)
188
    pub(crate) left_sort_exprs: Option<Vec<PhysicalSortExpr>>,
189
    /// Right side sort expression(s)
190
    pub(crate) right_sort_exprs: Option<Vec<PhysicalSortExpr>>,
191
    /// Partition Mode
192
    mode: StreamJoinPartitionMode,
193
    /// Cache holding plan properties like equivalences, output partitioning etc.
194
    cache: PlanProperties,
195
}
196
197
impl SymmetricHashJoinExec {
198
    /// Tries to create a new [SymmetricHashJoinExec].
199
    /// # Error
200
    /// This function errors when:
201
    /// - It is not possible to join the left and right sides on keys `on`, or
202
    /// - It fails to construct `SortedFilterExpr`s, or
203
    /// - It fails to create the [ExprIntervalGraph].
204
    #[allow(clippy::too_many_arguments)]
205
333
    pub fn try_new(
206
333
        left: Arc<dyn ExecutionPlan>,
207
333
        right: Arc<dyn ExecutionPlan>,
208
333
        on: JoinOn,
209
333
        filter: Option<JoinFilter>,
210
333
        join_type: &JoinType,
211
333
        null_equals_null: bool,
212
333
        left_sort_exprs: Option<Vec<PhysicalSortExpr>>,
213
333
        right_sort_exprs: Option<Vec<PhysicalSortExpr>>,
214
333
        mode: StreamJoinPartitionMode,
215
333
    ) -> Result<Self> {
216
333
        let left_schema = left.schema();
217
333
        let right_schema = right.schema();
218
333
219
333
        // Error out if no "on" constraints are given:
220
333
        if on.is_empty() {
221
0
            return plan_err!(
222
0
                "On constraints in SymmetricHashJoinExec should be non-empty"
223
0
            );
224
333
        }
225
333
226
333
        // Check if the join is valid with the given on constraints:
227
333
        check_join_is_valid(&left_schema, &right_schema, &on)
?0
;
228
229
        // Build the join schema from the left and right schemas:
230
333
        let (schema, column_indices) =
231
333
            build_join_schema(&left_schema, &right_schema, join_type);
232
333
233
333
        // Initialize the random state for the join operation:
234
333
        let random_state = RandomState::with_seeds(0, 0, 0, 0);
235
333
        let schema = Arc::new(schema);
236
333
        let cache =
237
333
            Self::compute_properties(&left, &right, Arc::clone(&schema), *join_type, &on);
238
333
        Ok(SymmetricHashJoinExec {
239
333
            left,
240
333
            right,
241
333
            on,
242
333
            filter,
243
333
            join_type: *join_type,
244
333
            random_state,
245
333
            metrics: ExecutionPlanMetricsSet::new(),
246
333
            column_indices,
247
333
            null_equals_null,
248
333
            left_sort_exprs,
249
333
            right_sort_exprs,
250
333
            mode,
251
333
            cache,
252
333
        })
253
333
    }
254
255
    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
256
333
    fn compute_properties(
257
333
        left: &Arc<dyn ExecutionPlan>,
258
333
        right: &Arc<dyn ExecutionPlan>,
259
333
        schema: SchemaRef,
260
333
        join_type: JoinType,
261
333
        join_on: JoinOnRef,
262
333
    ) -> PlanProperties {
263
333
        // Calculate equivalence properties:
264
333
        let eq_properties = join_equivalence_properties(
265
333
            left.equivalence_properties().clone(),
266
333
            right.equivalence_properties().clone(),
267
333
            &join_type,
268
333
            schema,
269
333
            &[false, false],
270
333
            // Has alternating probe side
271
333
            None,
272
333
            join_on,
273
333
        );
274
333
275
333
        let output_partitioning =
276
333
            symmetric_join_output_partitioning(left, right, &join_type);
277
333
278
333
        // Determine execution mode:
279
333
        let mode = execution_mode_from_children([left, right]);
280
333
281
333
        PlanProperties::new(eq_properties, output_partitioning, mode)
282
333
    }
283
284
    /// left stream
285
0
    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
286
0
        &self.left
287
0
    }
288
289
    /// right stream
290
0
    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
291
0
        &self.right
292
0
    }
293
294
    /// Set of common columns used to join on
295
0
    pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
296
0
        &self.on
297
0
    }
298
299
    /// Filters applied before join output
300
0
    pub fn filter(&self) -> Option<&JoinFilter> {
301
0
        self.filter.as_ref()
302
0
    }
303
304
    /// How the join is performed
305
0
    pub fn join_type(&self) -> &JoinType {
306
0
        &self.join_type
307
0
    }
308
309
    /// Get null_equals_null
310
0
    pub fn null_equals_null(&self) -> bool {
311
0
        self.null_equals_null
312
0
    }
313
314
    /// Get partition mode
315
0
    pub fn partition_mode(&self) -> StreamJoinPartitionMode {
316
0
        self.mode
317
0
    }
318
319
    /// Get left_sort_exprs
320
0
    pub fn left_sort_exprs(&self) -> Option<&[PhysicalSortExpr]> {
321
0
        self.left_sort_exprs.as_deref()
322
0
    }
323
324
    /// Get right_sort_exprs
325
0
    pub fn right_sort_exprs(&self) -> Option<&[PhysicalSortExpr]> {
326
0
        self.right_sort_exprs.as_deref()
327
0
    }
328
329
    /// Check if order information covers every column in the filter expression.
330
0
    pub fn check_if_order_information_available(&self) -> Result<bool> {
331
0
        if let Some(filter) = self.filter() {
332
0
            let left = self.left();
333
0
            if let Some(left_ordering) = left.output_ordering() {
334
0
                let right = self.right();
335
0
                if let Some(right_ordering) = right.output_ordering() {
336
0
                    let left_convertible = convert_sort_expr_with_filter_schema(
337
0
                        &JoinSide::Left,
338
0
                        filter,
339
0
                        &left.schema(),
340
0
                        &left_ordering[0],
341
0
                    )?
342
0
                    .is_some();
343
0
                    let right_convertible = convert_sort_expr_with_filter_schema(
344
0
                        &JoinSide::Right,
345
0
                        filter,
346
0
                        &right.schema(),
347
0
                        &right_ordering[0],
348
0
                    )?
349
0
                    .is_some();
350
0
                    return Ok(left_convertible && right_convertible);
351
0
                }
352
0
            }
353
0
        }
354
0
        Ok(false)
355
0
    }
356
}
357
358
impl DisplayAs for SymmetricHashJoinExec {
359
0
    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
360
0
        match t {
361
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
362
0
                let display_filter = self.filter.as_ref().map_or_else(
363
0
                    || "".to_string(),
364
0
                    |f| format!(", filter={}", f.expression()),
365
0
                );
366
0
                let on = self
367
0
                    .on
368
0
                    .iter()
369
0
                    .map(|(c1, c2)| format!("({}, {})", c1, c2))
370
0
                    .collect::<Vec<String>>()
371
0
                    .join(", ");
372
0
                write!(
373
0
                    f,
374
0
                    "SymmetricHashJoinExec: mode={:?}, join_type={:?}, on=[{}]{}",
375
0
                    self.mode, self.join_type, on, display_filter
376
0
                )
377
0
            }
378
0
        }
379
0
    }
380
}
381
382
impl ExecutionPlan for SymmetricHashJoinExec {
383
0
    fn name(&self) -> &'static str {
384
0
        "SymmetricHashJoinExec"
385
0
    }
386
387
0
    fn as_any(&self) -> &dyn Any {
388
0
        self
389
0
    }
390
391
1.33k
    fn properties(&self) -> &PlanProperties {
392
1.33k
        &self.cache
393
1.33k
    }
394
395
0
    fn required_input_distribution(&self) -> Vec<Distribution> {
396
0
        match self.mode {
397
            StreamJoinPartitionMode::Partitioned => {
398
0
                let (left_expr, right_expr) = self
399
0
                    .on
400
0
                    .iter()
401
0
                    .map(|(l, r)| (Arc::clone(l) as _, Arc::clone(r) as _))
402
0
                    .unzip();
403
0
                vec![
404
0
                    Distribution::HashPartitioned(left_expr),
405
0
                    Distribution::HashPartitioned(right_expr),
406
0
                ]
407
            }
408
            StreamJoinPartitionMode::SinglePartition => {
409
0
                vec![Distribution::SinglePartition, Distribution::SinglePartition]
410
            }
411
        }
412
0
    }
413
414
0
    fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> {
415
0
        vec![
416
0
            self.left_sort_exprs
417
0
                .as_ref()
418
0
                .map(PhysicalSortRequirement::from_sort_exprs),
419
0
            self.right_sort_exprs
420
0
                .as_ref()
421
0
                .map(PhysicalSortRequirement::from_sort_exprs),
422
0
        ]
423
0
    }
424
425
0
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
426
0
        vec![&self.left, &self.right]
427
0
    }
428
429
0
    fn with_new_children(
430
0
        self: Arc<Self>,
431
0
        children: Vec<Arc<dyn ExecutionPlan>>,
432
0
    ) -> Result<Arc<dyn ExecutionPlan>> {
433
0
        Ok(Arc::new(SymmetricHashJoinExec::try_new(
434
0
            Arc::clone(&children[0]),
435
0
            Arc::clone(&children[1]),
436
0
            self.on.clone(),
437
0
            self.filter.clone(),
438
0
            &self.join_type,
439
0
            self.null_equals_null,
440
0
            self.left_sort_exprs.clone(),
441
0
            self.right_sort_exprs.clone(),
442
0
            self.mode,
443
0
        )?))
444
0
    }
445
446
0
    fn metrics(&self) -> Option<MetricsSet> {
447
0
        Some(self.metrics.clone_inner())
448
0
    }
449
450
0
    fn statistics(&self) -> Result<Statistics> {
451
0
        // TODO stats: it is not possible in general to know the output size of joins
452
0
        Ok(Statistics::new_unknown(&self.schema()))
453
0
    }
454
455
1.33k
    fn execute(
456
1.33k
        &self,
457
1.33k
        partition: usize,
458
1.33k
        context: Arc<TaskContext>,
459
1.33k
    ) -> Result<SendableRecordBatchStream> {
460
1.33k
        let left_partitions = self.left.output_partitioning().partition_count();
461
1.33k
        let right_partitions = self.right.output_partitioning().partition_count();
462
1.33k
        if left_partitions != right_partitions {
463
0
            return internal_err!(
464
0
                "Invalid SymmetricHashJoinExec, partition count mismatch {left_partitions}!={right_partitions},\
465
0
                 consider using RepartitionExec"
466
0
            );
467
1.33k
        }
468
        // If `filter_state` and `filter` are both present, then calculate sorted filter expressions
469
        // for both sides, and build an expression graph.
470
1.33k
        let (left_sorted_filter_expr, right_sorted_filter_expr, graph) =
471
1.33k
            match (&self.left_sort_exprs, &self.right_sort_exprs, &self.filter) {
472
1.10k
                (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => {
473
1.10k
                    let (left, right, graph) = prepare_sorted_exprs(
474
1.10k
                        filter,
475
1.10k
                        &self.left,
476
1.10k
                        &self.right,
477
1.10k
                        left_sort_exprs,
478
1.10k
                        right_sort_exprs,
479
1.10k
                    )
?0
;
480
1.10k
                    (Some(left), Some(right), Some(graph))
481
                }
482
                // If `filter_state` or `filter` is not present, then return None for all three values:
483
224
                _ => (None, None, None),
484
            };
485
486
1.33k
        let (on_left, on_right) = self.on.iter().cloned().unzip();
487
1.33k
488
1.33k
        let left_side_joiner =
489
1.33k
            OneSideHashJoiner::new(JoinSide::Left, on_left, self.left.schema());
490
1.33k
        let right_side_joiner =
491
1.33k
            OneSideHashJoiner::new(JoinSide::Right, on_right, self.right.schema());
492
493
1.33k
        let left_stream = self.left.execute(partition, Arc::clone(&context))
?0
;
494
495
1.33k
        let right_stream = self.right.execute(partition, Arc::clone(&context))
?0
;
496
497
1.33k
        let reservation = Arc::new(Mutex::new(
498
1.33k
            MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]"))
499
1.33k
                .register(context.memory_pool()),
500
1.33k
        ));
501
1.33k
        if let Some(
g1.10k
) = graph.as_ref() {
502
1.10k
            reservation.lock().try_grow(g.size())
?0
;
503
224
        }
504
505
1.33k
        Ok(Box::pin(SymmetricHashJoinStream {
506
1.33k
            left_stream,
507
1.33k
            right_stream,
508
1.33k
            schema: self.schema(),
509
1.33k
            filter: self.filter.clone(),
510
1.33k
            join_type: self.join_type,
511
1.33k
            random_state: self.random_state.clone(),
512
1.33k
            left: left_side_joiner,
513
1.33k
            right: right_side_joiner,
514
1.33k
            column_indices: self.column_indices.clone(),
515
1.33k
            metrics: StreamJoinMetrics::new(partition, &self.metrics),
516
1.33k
            graph,
517
1.33k
            left_sorted_filter_expr,
518
1.33k
            right_sorted_filter_expr,
519
1.33k
            null_equals_null: self.null_equals_null,
520
1.33k
            state: SHJStreamState::PullRight,
521
1.33k
            reservation,
522
1.33k
        }))
523
1.33k
    }
524
}
525
526
/// A stream that issues [RecordBatch]es as they arrive from the right  of the join.
527
struct SymmetricHashJoinStream {
528
    /// Input streams
529
    left_stream: SendableRecordBatchStream,
530
    right_stream: SendableRecordBatchStream,
531
    /// Input schema
532
    schema: Arc<Schema>,
533
    /// join filter
534
    filter: Option<JoinFilter>,
535
    /// type of the join
536
    join_type: JoinType,
537
    // left hash joiner
538
    left: OneSideHashJoiner,
539
    /// right hash joiner
540
    right: OneSideHashJoiner,
541
    /// Information of index and left / right placement of columns
542
    column_indices: Vec<ColumnIndex>,
543
    // Expression graph for range pruning.
544
    graph: Option<ExprIntervalGraph>,
545
    // Left globally sorted filter expr
546
    left_sorted_filter_expr: Option<SortedFilterExpr>,
547
    // Right globally sorted filter expr
548
    right_sorted_filter_expr: Option<SortedFilterExpr>,
549
    /// Random state used for hashing initialization
550
    random_state: RandomState,
551
    /// If null_equals_null is true, null == null else null != null
552
    null_equals_null: bool,
553
    /// Metrics
554
    metrics: StreamJoinMetrics,
555
    /// Memory reservation
556
    reservation: SharedMemoryReservation,
557
    /// State machine for input execution
558
    state: SHJStreamState,
559
}
560
561
impl RecordBatchStream for SymmetricHashJoinStream {
562
0
    fn schema(&self) -> SchemaRef {
563
0
        Arc::clone(&self.schema)
564
0
    }
565
}
566
567
impl Stream for SymmetricHashJoinStream {
568
    type Item = Result<RecordBatch>;
569
570
4.78k
    fn poll_next(
571
4.78k
        mut self: std::pin::Pin<&mut Self>,
572
4.78k
        cx: &mut std::task::Context<'_>,
573
4.78k
    ) -> Poll<Option<Self::Item>> {
574
4.78k
        self.poll_next_impl(cx)
575
4.78k
    }
576
}
577
578
/// Determine the pruning length for `buffer`.
579
///
580
/// This function evaluates the build side filter expression, converts the
581
/// result into an array and determines the pruning length by performing a
582
/// binary search on the array.
583
///
584
/// # Arguments
585
///
586
/// * `buffer`: The record batch to be pruned.
587
/// * `build_side_filter_expr`: The filter expression on the build side used
588
///   to determine the pruning length.
589
///
590
/// # Returns
591
///
592
/// A [Result] object that contains the pruning length. The function will return
593
/// an error if
594
/// - there is an issue evaluating the build side filter expression;
595
/// - there is an issue converting the build side filter expression into an array
596
3.66k
fn determine_prune_length(
597
3.66k
    buffer: &RecordBatch,
598
3.66k
    build_side_filter_expr: &SortedFilterExpr,
599
3.66k
) -> Result<usize> {
600
3.66k
    let origin_sorted_expr = build_side_filter_expr.origin_sorted_expr();
601
3.66k
    let interval = build_side_filter_expr.interval();
602
    // Evaluate the build side filter expression and convert it into an array
603
3.66k
    let batch_arr = origin_sorted_expr
604
3.66k
        .expr
605
3.66k
        .evaluate(buffer)
?0
606
3.66k
        .into_array(buffer.num_rows())
?0
;
607
608
    // Get the lower or upper interval based on the sort direction
609
3.66k
    let target = if origin_sorted_expr.options.descending {
610
839
        interval.upper().clone()
611
    } else {
612
2.82k
        interval.lower().clone()
613
    };
614
615
    // Perform binary search on the array to determine the length of the record batch to be pruned
616
3.66k
    bisect::<true>(&[batch_arr], &[target], &[origin_sorted_expr.options])
617
3.66k
}
618
619
/// This method determines if the result of the join should be produced in the final step or not.
620
///
621
/// # Arguments
622
///
623
/// * `build_side` - Enum indicating the side of the join used as the build side.
624
/// * `join_type` - Enum indicating the type of join to be performed.
625
///
626
/// # Returns
627
///
628
/// A boolean indicating whether the result of the join should be produced in the final step or not.
629
/// The result will be true if the build side is JoinSide::Left and the join type is one of
630
/// JoinType::Left, JoinType::LeftAnti, JoinType::Full or JoinType::LeftSemi.
631
/// If the build side is JoinSide::Right, the result will be true if the join type
632
/// is one of JoinType::Right, JoinType::RightAnti, JoinType::Full, or JoinType::RightSemi.
633
19.3k
fn need_to_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> bool {
634
19.3k
    if build_side == JoinSide::Left {
635
4.98k
        matches!(
636
10.0k
            join_type,
637
            JoinType::Left | JoinType::LeftAnti | JoinType::Full | JoinType::LeftSemi
638
        )
639
    } else {
640
4.55k
        matches!(
641
9.22k
            join_type,
642
            JoinType::Right | JoinType::RightAnti | JoinType::Full | JoinType::RightSemi
643
        )
644
    }
645
19.3k
}
646
647
/// Calculate indices by join type.
648
///
649
/// This method returns a tuple of two arrays: build and probe indices.
650
/// The length of both arrays will be the same.
651
///
652
/// # Arguments
653
///
654
/// * `build_side`: Join side which defines the build side.
655
/// * `prune_length`: Length of the prune data.
656
/// * `visited_rows`: Hash set of visited rows of the build side.
657
/// * `deleted_offset`: Deleted offset of the build side.
658
/// * `join_type`: The type of join to be performed.
659
///
660
/// # Returns
661
///
662
/// A tuple of two arrays of primitive types representing the build and probe indices.
663
///
664
2.79k
fn calculate_indices_by_join_type<L: ArrowPrimitiveType, R: ArrowPrimitiveType>(
665
2.79k
    build_side: JoinSide,
666
2.79k
    prune_length: usize,
667
2.79k
    visited_rows: &HashSet<usize>,
668
2.79k
    deleted_offset: usize,
669
2.79k
    join_type: JoinType,
670
2.79k
) -> Result<(PrimitiveArray<L>, PrimitiveArray<R>)>
671
2.79k
where
672
2.79k
    NativeAdapter<L>: From<<L as ArrowPrimitiveType>::Native>,
673
2.79k
{
674
    // Store the result in a tuple
675
2.79k
    let 
result2.11k
= match (build_side, join_type) {
676
        // In the case of `Left` or `Right` join, or `Full` join, get the anti indices
677
        (JoinSide::Left, JoinType::Left | JoinType::LeftAnti)
678
        | (JoinSide::Right, JoinType::Right | JoinType::RightAnti)
679
        | (_, JoinType::Full) => {
680
1.42k
            let build_unmatched_indices =
681
1.42k
                get_pruning_anti_indices(prune_length, deleted_offset, visited_rows);
682
1.42k
            let mut builder =
683
1.42k
                PrimitiveBuilder::<R>::with_capacity(build_unmatched_indices.len());
684
1.42k
            builder.append_nulls(build_unmatched_indices.len());
685
1.42k
            let probe_indices = builder.finish();
686
1.42k
            (build_unmatched_indices, probe_indices)
687
        }
688
        // In the case of `LeftSemi` or `RightSemi` join, get the semi indices
689
        (JoinSide::Left, JoinType::LeftSemi) | (JoinSide::Right, JoinType::RightSemi) => {
690
683
            let build_unmatched_indices =
691
683
                get_pruning_semi_indices(prune_length, deleted_offset, visited_rows);
692
683
            let mut builder =
693
683
                PrimitiveBuilder::<R>::with_capacity(build_unmatched_indices.len());
694
683
            builder.append_nulls(build_unmatched_indices.len());
695
683
            let probe_indices = builder.finish();
696
683
            (build_unmatched_indices, probe_indices)
697
        }
698
        // The case of other join types is not considered
699
0
        _ => unreachable!(),
700
    };
701
2.11k
    Ok(result)
702
2.11k
}
703
704
/// This function produces unmatched record results based on the build side,
705
/// join type and other parameters.
706
///
707
/// The method uses first `prune_length` rows from the build side input buffer
708
/// to produce results.
709
///
710
/// # Arguments
711
///
712
/// * `output_schema` - The schema of the final output record batch.
713
/// * `prune_length` - The length of the determined prune length.
714
/// * `probe_schema` - The schema of the probe [RecordBatch].
715
/// * `join_type` - The type of join to be performed.
716
/// * `column_indices` - Indices of columns that are being joined.
717
///
718
/// # Returns
719
///
720
/// * `Option<RecordBatch>` - The final output record batch if required, otherwise [None].
721
9.44k
pub(crate) fn build_side_determined_results(
722
9.44k
    build_hash_joiner: &OneSideHashJoiner,
723
9.44k
    output_schema: &SchemaRef,
724
9.44k
    prune_length: usize,
725
9.44k
    probe_schema: SchemaRef,
726
9.44k
    join_type: JoinType,
727
9.44k
    column_indices: &[ColumnIndex],
728
9.44k
) -> Result<Option<RecordBatch>> {
729
9.44k
    // Check if we need to produce a result in the final output:
730
9.44k
    if prune_length > 0
731
5.52k
        && need_to_produce_result_in_final(build_hash_joiner.build_side, join_type)
732
    {
733
        // Calculate the indices for build and probe sides based on join type and build side:
734
2.79k
        let (build_indices, probe_indices) = calculate_indices_by_join_type(
735
2.79k
            build_hash_joiner.build_side,
736
2.79k
            prune_length,
737
2.79k
            &build_hash_joiner.visited_rows,
738
2.79k
            build_hash_joiner.deleted_offset,
739
2.79k
            join_type,
740
2.79k
        )
?0
;
741
742
        // Create an empty probe record batch:
743
2.79k
        let empty_probe_batch = RecordBatch::new_empty(probe_schema);
744
2.79k
        // Build the final result from the indices of build and probe sides:
745
2.79k
        build_batch_from_indices(
746
2.79k
            output_schema.as_ref(),
747
2.79k
            &build_hash_joiner.input_buffer,
748
2.79k
            &empty_probe_batch,
749
2.79k
            &build_indices,
750
2.79k
            &probe_indices,
751
2.79k
            column_indices,
752
2.79k
            build_hash_joiner.build_side,
753
2.79k
        )
754
2.79k
        .map(|batch| (batch.num_rows() > 0).then_some(batch))
755
    } else {
756
        // If we don't need to produce a result, return None
757
6.65k
        Ok(None)
758
    }
759
9.44k
}
760
761
/// This method performs a join between the build side input buffer and the probe side batch.
762
///
763
/// # Arguments
764
///
765
/// * `build_hash_joiner` - Build side hash joiner
766
/// * `probe_hash_joiner` - Probe side hash joiner
767
/// * `schema` - A reference to the schema of the output record batch.
768
/// * `join_type` - The type of join to be performed.
769
/// * `on_probe` - An array of columns on which the join will be performed. The columns are from the probe side of the join.
770
/// * `filter` - An optional filter on the join condition.
771
/// * `probe_batch` - The second record batch to be joined.
772
/// * `column_indices` - An array of columns to be selected for the result of the join.
773
/// * `random_state` - The random state for the join.
774
/// * `null_equals_null` - A boolean indicating whether NULL values should be treated as equal when joining.
775
///
776
/// # Returns
777
///
778
/// A [Result] containing an optional record batch if the join type is not one of `LeftAnti`, `RightAnti`, `LeftSemi` or `RightSemi`.
779
/// If the join type is one of the above four, the function will return [None].
780
#[allow(clippy::too_many_arguments)]
781
8.11k
pub(crate) fn join_with_probe_batch(
782
8.11k
    build_hash_joiner: &mut OneSideHashJoiner,
783
8.11k
    probe_hash_joiner: &mut OneSideHashJoiner,
784
8.11k
    schema: &SchemaRef,
785
8.11k
    join_type: JoinType,
786
8.11k
    filter: Option<&JoinFilter>,
787
8.11k
    probe_batch: &RecordBatch,
788
8.11k
    column_indices: &[ColumnIndex],
789
8.11k
    random_state: &RandomState,
790
8.11k
    null_equals_null: bool,
791
8.11k
) -> Result<Option<RecordBatch>> {
792
8.11k
    if build_hash_joiner.input_buffer.num_rows() == 0 || 
probe_batch.num_rows() == 06.88k
{
793
1.23k
        return Ok(None);
794
6.88k
    }
795
6.88k
    let (build_indices, probe_indices) = lookup_join_hashmap(
796
6.88k
        &build_hash_joiner.hashmap,
797
6.88k
        &build_hash_joiner.input_buffer,
798
6.88k
        probe_batch,
799
6.88k
        &build_hash_joiner.on,
800
6.88k
        &probe_hash_joiner.on,
801
6.88k
        random_state,
802
6.88k
        null_equals_null,
803
6.88k
        &mut build_hash_joiner.hashes_buffer,
804
6.88k
        Some(build_hash_joiner.deleted_offset),
805
6.88k
    )
?0
;
806
807
6.88k
    let (build_indices, probe_indices) = if let Some(
filter6.73k
) = filter {
808
6.73k
        apply_join_filter_to_indices(
809
6.73k
            &build_hash_joiner.input_buffer,
810
6.73k
            probe_batch,
811
6.73k
            build_indices,
812
6.73k
            probe_indices,
813
6.73k
            filter,
814
6.73k
            build_hash_joiner.build_side,
815
6.73k
        )
?0
816
    } else {
817
152
        (build_indices, probe_indices)
818
    };
819
820
6.88k
    if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) {
821
3.48k
        record_visited_indices(
822
3.48k
            &mut build_hash_joiner.visited_rows,
823
3.48k
            build_hash_joiner.deleted_offset,
824
3.48k
            &build_indices,
825
3.48k
        );
826
3.48k
    }
3.40k
827
6.88k
    if need_to_produce_result_in_final(build_hash_joiner.build_side.negate(), join_type) {
828
3.48k
        record_visited_indices(
829
3.48k
            &mut probe_hash_joiner.visited_rows,
830
3.48k
            probe_hash_joiner.offset,
831
3.48k
            &probe_indices,
832
3.48k
        );
833
3.48k
    }
3.40k
834
3.48k
    if matches!(
835
6.88k
        join_type,
836
        JoinType::LeftAnti
837
            | JoinType::RightAnti
838
            | JoinType::LeftSemi
839
            | JoinType::RightSemi
840
    ) {
841
3.40k
        Ok(None)
842
    } else {
843
3.48k
        build_batch_from_indices(
844
3.48k
            schema,
845
3.48k
            &build_hash_joiner.input_buffer,
846
3.48k
            probe_batch,
847
3.48k
            &build_indices,
848
3.48k
            &probe_indices,
849
3.48k
            column_indices,
850
3.48k
            build_hash_joiner.build_side,
851
3.48k
        )
852
3.48k
        .map(|batch| (batch.num_rows() > 0).then_some(batch))
853
    }
854
8.11k
}
855
856
/// This method performs lookups against JoinHashMap by hash values of join-key columns, and handles potential
857
/// hash collisions.
858
///
859
/// # Arguments
860
///
861
/// * `build_hashmap` - hashmap collected from build side data.
862
/// * `build_batch` - Build side record batch.
863
/// * `probe_batch` - Probe side record batch.
864
/// * `build_on` - An array of columns on which the join will be performed. The columns are from the build side of the join.
865
/// * `probe_on` - An array of columns on which the join will be performed. The columns are from the probe side of the join.
866
/// * `random_state` - The random state for the join.
867
/// * `null_equals_null` - A boolean indicating whether NULL values should be treated as equal when joining.
868
/// * `hashes_buffer` - Buffer used for probe side keys hash calculation.
869
/// * `deleted_offset` - deleted offset for build side data.
870
///
871
/// # Returns
872
///
873
/// A [Result] containing a tuple with two equal length arrays, representing indices of rows from build and probe side,
874
/// matched by join key columns.
875
#[allow(clippy::too_many_arguments)]
876
6.88k
fn lookup_join_hashmap(
877
6.88k
    build_hashmap: &PruningJoinHashMap,
878
6.88k
    build_batch: &RecordBatch,
879
6.88k
    probe_batch: &RecordBatch,
880
6.88k
    build_on: &[PhysicalExprRef],
881
6.88k
    probe_on: &[PhysicalExprRef],
882
6.88k
    random_state: &RandomState,
883
6.88k
    null_equals_null: bool,
884
6.88k
    hashes_buffer: &mut Vec<u64>,
885
6.88k
    deleted_offset: Option<usize>,
886
6.88k
) -> Result<(UInt64Array, UInt32Array)> {
887
6.88k
    let keys_values = probe_on
888
6.88k
        .iter()
889
6.88k
        .map(|c| c.evaluate(probe_batch)
?0
.into_array(probe_batch.num_rows()))
890
6.88k
        .collect::<Result<Vec<_>>>()
?0
;
891
6.88k
    let build_join_values = build_on
892
6.88k
        .iter()
893
6.88k
        .map(|c| c.evaluate(build_batch)
?0
.into_array(build_batch.num_rows()))
894
6.88k
        .collect::<Result<Vec<_>>>()
?0
;
895
896
6.88k
    hashes_buffer.clear();
897
6.88k
    hashes_buffer.resize(probe_batch.num_rows(), 0);
898
6.88k
    let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)
?0
;
899
900
    // As SymmetricHashJoin uses LIFO JoinHashMap, the chained list algorithm
901
    // will return build indices for each probe row in a reverse order as such:
902
    // Build Indices: [5, 4, 3]
903
    // Probe Indices: [1, 1, 1]
904
    //
905
    // This affects the output sequence. Hypothetically, it's possible to preserve the lexicographic order on the build side.
906
    // Let's consider probe rows [0,1] as an example:
907
    //
908
    // When the probe iteration sequence is reversed, the following pairings can be derived:
909
    //
910
    // For probe row 1:
911
    //     (5, 1)
912
    //     (4, 1)
913
    //     (3, 1)
914
    //
915
    // For probe row 0:
916
    //     (5, 0)
917
    //     (4, 0)
918
    //     (3, 0)
919
    //
920
    // After reversing both sets of indices, we obtain reversed indices:
921
    //
922
    //     (3,0)
923
    //     (4,0)
924
    //     (5,0)
925
    //     (3,1)
926
    //     (4,1)
927
    //     (5,1)
928
    //
929
    // With this approach, the lexicographic order on both the probe side and the build side is preserved.
930
6.88k
    let (mut matched_probe, mut matched_build) = build_hashmap
931
6.88k
        .get_matched_indices(hash_values.iter().enumerate().rev(), deleted_offset);
932
6.88k
933
6.88k
    matched_probe.reverse();
934
6.88k
    matched_build.reverse();
935
6.88k
936
6.88k
    let build_indices: UInt64Array = matched_build.into();
937
6.88k
    let probe_indices: UInt32Array = matched_probe.into();
938
939
6.88k
    let (build_indices, probe_indices) = equal_rows_arr(
940
6.88k
        &build_indices,
941
6.88k
        &probe_indices,
942
6.88k
        &build_join_values,
943
6.88k
        &keys_values,
944
6.88k
        null_equals_null,
945
6.88k
    )
?0
;
946
947
6.88k
    Ok((build_indices, probe_indices))
948
6.88k
}
949
950
pub struct OneSideHashJoiner {
951
    /// Build side
952
    build_side: JoinSide,
953
    /// Input record batch buffer
954
    pub input_buffer: RecordBatch,
955
    /// Columns from the side
956
    pub(crate) on: Vec<PhysicalExprRef>,
957
    /// Hashmap
958
    pub(crate) hashmap: PruningJoinHashMap,
959
    /// Reuse the hashes buffer
960
    pub(crate) hashes_buffer: Vec<u64>,
961
    /// Matched rows
962
    pub(crate) visited_rows: HashSet<usize>,
963
    /// Offset
964
    pub(crate) offset: usize,
965
    /// Deleted offset
966
    pub(crate) deleted_offset: usize,
967
}
968
969
impl OneSideHashJoiner {
970
16.2k
    pub fn size(&self) -> usize {
971
16.2k
        let mut size = 0;
972
16.2k
        size += std::mem::size_of_val(self);
973
16.2k
        size += std::mem::size_of_val(&self.build_side);
974
16.2k
        size += self.input_buffer.get_array_memory_size();
975
16.2k
        size += std::mem::size_of_val(&self.on);
976
16.2k
        size += self.hashmap.size();
977
16.2k
        size += self.hashes_buffer.capacity() * std::mem::size_of::<u64>();
978
16.2k
        size += self.visited_rows.capacity() * std::mem::size_of::<usize>();
979
16.2k
        size += std::mem::size_of_val(&self.offset);
980
16.2k
        size += std::mem::size_of_val(&self.deleted_offset);
981
16.2k
        size
982
16.2k
    }
983
2.66k
    pub fn new(
984
2.66k
        build_side: JoinSide,
985
2.66k
        on: Vec<PhysicalExprRef>,
986
2.66k
        schema: SchemaRef,
987
2.66k
    ) -> Self {
988
2.66k
        Self {
989
2.66k
            build_side,
990
2.66k
            input_buffer: RecordBatch::new_empty(schema),
991
2.66k
            on,
992
2.66k
            hashmap: PruningJoinHashMap::with_capacity(0),
993
2.66k
            hashes_buffer: vec![],
994
2.66k
            visited_rows: HashSet::new(),
995
2.66k
            offset: 0,
996
2.66k
            deleted_offset: 0,
997
2.66k
        }
998
2.66k
    }
999
1000
    /// Updates the internal state of the [OneSideHashJoiner] with the incoming batch.
1001
    ///
1002
    /// # Arguments
1003
    ///
1004
    /// * `batch` - The incoming [RecordBatch] to be merged with the internal input buffer
1005
    /// * `random_state` - The random state used to hash values
1006
    ///
1007
    /// # Returns
1008
    ///
1009
    /// Returns a [Result] encapsulating any intermediate errors.
1010
8.11k
    pub(crate) fn update_internal_state(
1011
8.11k
        &mut self,
1012
8.11k
        batch: &RecordBatch,
1013
8.11k
        random_state: &RandomState,
1014
8.11k
    ) -> Result<()> {
1015
        // Merge the incoming batch with the existing input buffer:
1016
8.11k
        self.input_buffer = concat_batches(&batch.schema(), [&self.input_buffer, batch])
?0
;
1017
        // Resize the hashes buffer to the number of rows in the incoming batch:
1018
8.11k
        self.hashes_buffer.resize(batch.num_rows(), 0);
1019
8.11k
        // Get allocation_info before adding the item
1020
8.11k
        // Update the hashmap with the join key values and hashes of the incoming batch:
1021
8.11k
        update_hash(
1022
8.11k
            &self.on,
1023
8.11k
            batch,
1024
8.11k
            &mut self.hashmap,
1025
8.11k
            self.offset,
1026
8.11k
            random_state,
1027
8.11k
            &mut self.hashes_buffer,
1028
8.11k
            self.deleted_offset,
1029
8.11k
            false,
1030
8.11k
        )
?0
;
1031
8.11k
        Ok(())
1032
8.11k
    }
1033
1034
    /// Calculate prune length.
1035
    ///
1036
    /// # Arguments
1037
    ///
1038
    /// * `build_side_sorted_filter_expr` - Build side mutable sorted filter expression..
1039
    /// * `probe_side_sorted_filter_expr` - Probe side mutable sorted filter expression.
1040
    /// * `graph` - A mutable reference to the physical expression graph.
1041
    ///
1042
    /// # Returns
1043
    ///
1044
    /// A Result object that contains the pruning length.
1045
6.78k
    pub(crate) fn calculate_prune_length_with_probe_batch(
1046
6.78k
        &mut self,
1047
6.78k
        build_side_sorted_filter_expr: &mut SortedFilterExpr,
1048
6.78k
        probe_side_sorted_filter_expr: &mut SortedFilterExpr,
1049
6.78k
        graph: &mut ExprIntervalGraph,
1050
6.78k
    ) -> Result<usize> {
1051
6.78k
        // Return early if the input buffer is empty:
1052
6.78k
        if self.input_buffer.num_rows() == 0 {
1053
1.05k
            return Ok(0);
1054
5.72k
        }
1055
5.72k
        // Process the build and probe side sorted filter expressions if both are present:
1056
5.72k
        // Collect the sorted filter expressions into a vector of (node_index, interval) tuples:
1057
5.72k
        let mut filter_intervals = vec![];
1058
11.4k
        for expr in [
1059
5.72k
            &build_side_sorted_filter_expr,
1060
5.72k
            &probe_side_sorted_filter_expr,
1061
        ] {
1062
11.4k
            filter_intervals.push((expr.node_index(), expr.interval().clone()))
1063
        }
1064
        // Update the physical expression graph using the join filter intervals:
1065
5.72k
        graph.update_ranges(&mut filter_intervals, Interval::CERTAINLY_TRUE)
?0
;
1066
        // Extract the new join filter interval for the build side:
1067
5.72k
        let calculated_build_side_interval = filter_intervals.remove(0).1;
1068
5.72k
        // If the intervals have not changed, return early without pruning:
1069
5.72k
        if calculated_build_side_interval.eq(build_side_sorted_filter_expr.interval()) {
1070
2.06k
            return Ok(0);
1071
3.66k
        }
1072
3.66k
        // Update the build side interval and determine the pruning length:
1073
3.66k
        build_side_sorted_filter_expr.set_interval(calculated_build_side_interval);
1074
3.66k
1075
3.66k
        determine_prune_length(&self.input_buffer, build_side_sorted_filter_expr)
1076
6.78k
    }
1077
1078
6.78k
    pub(crate) fn prune_internal_state(&mut self, prune_length: usize) -> Result<()> {
1079
6.78k
        // Prune the hash values:
1080
6.78k
        self.hashmap.prune_hash_values(
1081
6.78k
            prune_length,
1082
6.78k
            self.deleted_offset as u64,
1083
6.78k
            HASHMAP_SHRINK_SCALE_FACTOR,
1084
6.78k
        );
1085
        // Remove pruned rows from the visited rows set:
1086
9.37k
        for row in 
self.deleted_offset..(self.deleted_offset + prune_length)6.78k
{
1087
9.37k
            self.visited_rows.remove(&row);
1088
9.37k
        }
1089
        // Update the input buffer after pruning:
1090
6.78k
        self.input_buffer = self
1091
6.78k
            .input_buffer
1092
6.78k
            .slice(prune_length, self.input_buffer.num_rows() - prune_length);
1093
6.78k
        // Increment the deleted offset:
1094
6.78k
        self.deleted_offset += prune_length;
1095
6.78k
        Ok(())
1096
6.78k
    }
1097
}
1098
1099
/// `SymmetricHashJoinStream` manages incremental join operations between two
1100
/// streams. Unlike traditional join approaches that need to scan one side of
1101
/// the join fully before proceeding, `SymmetricHashJoinStream` facilitates
1102
/// more dynamic join operations by working with streams as they emit data. This
1103
/// approach allows for more efficient processing, particularly in scenarios
1104
/// where waiting for complete data materialization is not feasible or optimal.
1105
/// This implementation provides a framework for handling various states of such a join
1106
/// process, ensuring that join logic is efficiently executed as data becomes
1107
/// available from either stream.
1108
///
1109
/// This implementation performs eager joins of data from two different asynchronous
1110
/// streams, typically referred to as left and right streams. The implementation
1111
/// provides a comprehensive set of methods to control and execute the join
1112
/// process, leveraging the states defined in `SHJStreamState`. Methods are
1113
/// primarily focused on asynchronously fetching data batches from each stream,
1114
/// processing them, and managing transitions between various states of the join.
1115
///
1116
/// This implementation uses a state machine approach to navigate different
1117
/// stages of the join operation, handling data from both streams and determining
1118
/// when the join completes.
1119
///
1120
/// State Transitions:
1121
/// - From `PullLeft` to `PullRight` or `LeftExhausted`:
1122
///   - In `fetch_next_from_left_stream`, when fetching a batch from the left stream:
1123
///     - On success (`Some(Ok(batch))`), state transitions to `PullRight` for
1124
///       processing the batch.
1125
///     - On error (`Some(Err(e))`), the error is returned, and the state remains
1126
///       unchanged.
1127
///     - On no data (`None`), state changes to `LeftExhausted`, returning `Continue`
1128
///       to proceed with the join process.
1129
/// - From `PullRight` to `PullLeft` or `RightExhausted`:
1130
///   - In `fetch_next_from_right_stream`, when fetching from the right stream:
1131
///     - If a batch is available, state changes to `PullLeft` for processing.
1132
///     - On error, the error is returned without changing the state.
1133
///     - If right stream is exhausted (`None`), state transitions to `RightExhausted`,
1134
///       with a `Continue` result.
1135
/// - Handling `RightExhausted` and `LeftExhausted`:
1136
///   - Methods `handle_right_stream_end` and `handle_left_stream_end` manage scenarios
1137
///     when streams are exhausted:
1138
///     - They attempt to continue processing with the other stream.
1139
///     - If both streams are exhausted, state changes to `BothExhausted { final_result: false }`.
1140
/// - Transition to `BothExhausted { final_result: true }`:
1141
///   - Occurs in `prepare_for_final_results_after_exhaustion` when both streams are
1142
///     exhausted, indicating completion of processing and availability of final results.
1143
impl SymmetricHashJoinStream {
1144
    /// Implements the main polling logic for the join stream.
1145
    ///
1146
    /// This method continuously checks the state of the join stream and
1147
    /// acts accordingly by delegating the handling to appropriate sub-methods
1148
    /// depending on the current state.
1149
    ///
1150
    /// # Arguments
1151
    ///
1152
    /// * `cx` - A context that facilitates cooperative non-blocking execution within a task.
1153
    ///
1154
    /// # Returns
1155
    ///
1156
    /// * `Poll<Option<Result<RecordBatch>>>` - A polled result, either a `RecordBatch` or None.
1157
4.78k
    fn poll_next_impl(
1158
4.78k
        &mut self,
1159
4.78k
        cx: &mut Context<'_>,
1160
4.78k
    ) -> Poll<Option<Result<RecordBatch>>> {
1161
        loop {
1162
14.1k
            return match self.state() {
1163
                SHJStreamState::PullRight => {
1164
5.61k
                    
handle_state!0
(
ready!341
(self.fetch_next_from_right_stream(cx)))
1165
                }
1166
                SHJStreamState::PullLeft => {
1167
4.40k
                    
handle_state!0
(
ready!326
(self.fetch_next_from_left_stream(cx)))
1168
                }
1169
                SHJStreamState::RightExhausted => {
1170
1.22k
                    
handle_state!0
(
ready!28
(self.handle_right_stream_end(cx)))
1171
                }
1172
                SHJStreamState::LeftExhausted => {
1173
226
                    
handle_state!0
(
ready!2
(self.handle_left_stream_end(cx)))
1174
                }
1175
                SHJStreamState::BothExhausted {
1176
                    final_result: false,
1177
                } => {
1178
1.33k
                    
handle_state!0
(self.prepare_for_final_results_after_exhaustion())
1179
                }
1180
1.33k
                SHJStreamState::BothExhausted { final_result: true } => Poll::Ready(None),
1181
            };
1182
        }
1183
4.78k
    }
1184
    /// Asynchronously pulls the next batch from the right stream.
1185
    ///
1186
    /// This default implementation checks for the next value in the right stream.
1187
    /// If a batch is found, the state is switched to `PullLeft`, and the batch handling
1188
    /// is delegated to `process_batch_from_right`. If the stream ends, the state is set to `RightExhausted`.
1189
    ///
1190
    /// # Returns
1191
    ///
1192
    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after pulling the batch.
1193
5.61k
    fn fetch_next_from_right_stream(
1194
5.61k
        &mut self,
1195
5.61k
        cx: &mut Context<'_>,
1196
5.61k
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1197
5.61k
        match 
ready!341
(self.right_stream().poll_next_unpin(cx)) {
1198
4.08k
            Some(Ok(batch)) => {
1199
4.08k
                if batch.num_rows() == 0 {
1200
0
                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1201
4.08k
                }
1202
4.08k
                self.set_state(SHJStreamState::PullLeft);
1203
4.08k
                Poll::Ready(self.process_batch_from_right(batch))
1204
            }
1205
0
            Some(Err(e)) => Poll::Ready(Err(e)),
1206
            None => {
1207
1.19k
                self.set_state(SHJStreamState::RightExhausted);
1208
1.19k
                Poll::Ready(Ok(StatefulStreamResult::Continue))
1209
            }
1210
        }
1211
5.61k
    }
1212
1213
    /// Asynchronously pulls the next batch from the left stream.
1214
    ///
1215
    /// This default implementation checks for the next value in the left stream.
1216
    /// If a batch is found, the state is switched to `PullRight`, and the batch handling
1217
    /// is delegated to `process_batch_from_left`. If the stream ends, the state is set to `LeftExhausted`.
1218
    ///
1219
    /// # Returns
1220
    ///
1221
    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after pulling the batch.
1222
4.40k
    fn fetch_next_from_left_stream(
1223
4.40k
        &mut self,
1224
4.40k
        cx: &mut Context<'_>,
1225
4.40k
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1226
4.40k
        match 
ready!326
(self.left_stream().poll_next_unpin(cx)) {
1227
3.94k
            Some(Ok(batch)) => {
1228
3.94k
                if batch.num_rows() == 0 {
1229
0
                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1230
3.94k
                }
1231
3.94k
                self.set_state(SHJStreamState::PullRight);
1232
3.94k
                Poll::Ready(self.process_batch_from_left(batch))
1233
            }
1234
0
            Some(Err(e)) => Poll::Ready(Err(e)),
1235
            None => {
1236
138
                self.set_state(SHJStreamState::LeftExhausted);
1237
138
                Poll::Ready(Ok(StatefulStreamResult::Continue))
1238
            }
1239
        }
1240
4.40k
    }
1241
1242
    /// Asynchronously handles the scenario when the right stream is exhausted.
1243
    ///
1244
    /// In this default implementation, when the right stream is exhausted, it attempts
1245
    /// to pull from the left stream. If a batch is found in the left stream, it delegates
1246
    /// the handling to `process_batch_from_left`. If both streams are exhausted, the state is set
1247
    /// to indicate both streams are exhausted without final results yet.
1248
    ///
1249
    /// # Returns
1250
    ///
1251
    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after checking the exhaustion state.
1252
1.22k
    fn handle_right_stream_end(
1253
1.22k
        &mut self,
1254
1.22k
        cx: &mut Context<'_>,
1255
1.22k
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1256
1.22k
        match 
ready!28
(self.left_stream().poll_next_unpin(cx)) {
1257
3
            Some(Ok(batch)) => {
1258
3
                if batch.num_rows() == 0 {
1259
0
                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1260
3
                }
1261
3
                Poll::Ready(self.process_batch_after_right_end(batch))
1262
            }
1263
0
            Some(Err(e)) => Poll::Ready(Err(e)),
1264
            None => {
1265
1.19k
                self.set_state(SHJStreamState::BothExhausted {
1266
1.19k
                    final_result: false,
1267
1.19k
                });
1268
1.19k
                Poll::Ready(Ok(StatefulStreamResult::Continue))
1269
            }
1270
        }
1271
1.22k
    }
1272
1273
    /// Asynchronously handles the scenario when the left stream is exhausted.
1274
    ///
1275
    /// When the left stream is exhausted, this default
1276
    /// implementation tries to pull from the right stream and delegates the batch
1277
    /// handling to `process_batch_after_left_end`. If both streams are exhausted, the state
1278
    /// is updated to indicate so.
1279
    ///
1280
    /// # Returns
1281
    ///
1282
    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after checking the exhaustion state.
1283
226
    fn handle_left_stream_end(
1284
226
        &mut self,
1285
226
        cx: &mut Context<'_>,
1286
226
    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
1287
226
        match 
ready!2
(self.right_stream().poll_next_unpin(cx)) {
1288
86
            Some(Ok(batch)) => {
1289
86
                if batch.num_rows() == 0 {
1290
0
                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
1291
86
                }
1292
86
                Poll::Ready(self.process_batch_after_left_end(batch))
1293
            }
1294
0
            Some(Err(e)) => Poll::Ready(Err(e)),
1295
            None => {
1296
138
                self.set_state(SHJStreamState::BothExhausted {
1297
138
                    final_result: false,
1298
138
                });
1299
138
                Poll::Ready(Ok(StatefulStreamResult::Continue))
1300
            }
1301
        }
1302
226
    }
1303
1304
    /// Handles the state when both streams are exhausted and final results are yet to be produced.
1305
    ///
1306
    /// This default implementation switches the state to indicate both streams are
1307
    /// exhausted with final results and then invokes the handling for this specific
1308
    /// scenario via `process_batches_before_finalization`.
1309
    ///
1310
    /// # Returns
1311
    ///
1312
    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after both streams are exhausted.
1313
1.33k
    fn prepare_for_final_results_after_exhaustion(
1314
1.33k
        &mut self,
1315
1.33k
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1316
1.33k
        self.set_state(SHJStreamState::BothExhausted { final_result: true });
1317
1.33k
        self.process_batches_before_finalization()
1318
1.33k
    }
1319
1320
4.16k
    fn process_batch_from_right(
1321
4.16k
        &mut self,
1322
4.16k
        batch: RecordBatch,
1323
4.16k
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1324
4.16k
        self.perform_join_for_given_side(batch, JoinSide::Right)
1325
4.16k
            .map(|maybe_batch| {
1326
4.16k
                if maybe_batch.is_some() {
1327
1.06k
                    StatefulStreamResult::Ready(maybe_batch)
1328
                } else {
1329
3.10k
                    StatefulStreamResult::Continue
1330
                }
1331
4.16k
            })
1332
4.16k
    }
1333
1334
3.94k
    fn process_batch_from_left(
1335
3.94k
        &mut self,
1336
3.94k
        batch: RecordBatch,
1337
3.94k
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1338
3.94k
        self.perform_join_for_given_side(batch, JoinSide::Left)
1339
3.94k
            .map(|maybe_batch| {
1340
3.94k
                if maybe_batch.is_some() {
1341
1.03k
                    StatefulStreamResult::Ready(maybe_batch)
1342
                } else {
1343
2.91k
                    StatefulStreamResult::Continue
1344
                }
1345
3.94k
            })
1346
3.94k
    }
1347
1348
86
    fn process_batch_after_left_end(
1349
86
        &mut self,
1350
86
        right_batch: RecordBatch,
1351
86
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1352
86
        self.process_batch_from_right(right_batch)
1353
86
    }
1354
1355
3
    fn process_batch_after_right_end(
1356
3
        &mut self,
1357
3
        left_batch: RecordBatch,
1358
3
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1359
3
        self.process_batch_from_left(left_batch)
1360
3
    }
1361
1362
1.33k
    fn process_batches_before_finalization(
1363
1.33k
        &mut self,
1364
1.33k
    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
1365
        // Get the left side results:
1366
1.33k
        let left_result = build_side_determined_results(
1367
1.33k
            &self.left,
1368
1.33k
            &self.schema,
1369
1.33k
            self.left.input_buffer.num_rows(),
1370
1.33k
            self.right.input_buffer.schema(),
1371
1.33k
            self.join_type,
1372
1.33k
            &self.column_indices,
1373
1.33k
        )
?0
;
1374
        // Get the right side results:
1375
1.33k
        let right_result = build_side_determined_results(
1376
1.33k
            &self.right,
1377
1.33k
            &self.schema,
1378
1.33k
            self.right.input_buffer.num_rows(),
1379
1.33k
            self.left.input_buffer.schema(),
1380
1.33k
            self.join_type,
1381
1.33k
            &self.column_indices,
1382
1.33k
        )
?0
;
1383
1384
        // Combine the left and right results:
1385
1.33k
        let result = combine_two_batches(&self.schema, left_result, right_result)
?0
;
1386
1387
        // Update the metrics and return the result:
1388
1.33k
        if let Some(
batch657
) = &result {
1389
            // Update the metrics:
1390
657
            self.metrics.output_batches.add(1);
1391
657
            self.metrics.output_rows.add(batch.num_rows());
1392
657
            return Ok(StatefulStreamResult::Ready(result));
1393
675
        }
1394
675
        Ok(StatefulStreamResult::Continue)
1395
1.33k
    }
1396
1397
5.84k
    fn right_stream(&mut self) -> &mut SendableRecordBatchStream {
1398
5.84k
        &mut self.right_stream
1399
5.84k
    }
1400
1401
5.63k
    fn left_stream(&mut self) -> &mut SendableRecordBatchStream {
1402
5.63k
        &mut self.left_stream
1403
5.63k
    }
1404
1405
12.0k
    fn set_state(&mut self, state: SHJStreamState) {
1406
12.0k
        self.state = state;
1407
12.0k
    }
1408
1409
14.1k
    fn state(&mut self) -> SHJStreamState {
1410
14.1k
        self.state.clone()
1411
14.1k
    }
1412
1413
8.11k
    fn size(&self) -> usize {
1414
8.11k
        let mut size = 0;
1415
8.11k
        size += std::mem::size_of_val(&self.schema);
1416
8.11k
        size += std::mem::size_of_val(&self.filter);
1417
8.11k
        size += std::mem::size_of_val(&self.join_type);
1418
8.11k
        size += self.left.size();
1419
8.11k
        size += self.right.size();
1420
8.11k
        size += std::mem::size_of_val(&self.column_indices);
1421
8.11k
        size += self.graph.as_ref().map(|g| 
g.size()6.78k
).unwrap_or(0);
1422
8.11k
        size += std::mem::size_of_val(&self.left_sorted_filter_expr);
1423
8.11k
        size += std::mem::size_of_val(&self.right_sorted_filter_expr);
1424
8.11k
        size += std::mem::size_of_val(&self.random_state);
1425
8.11k
        size += std::mem::size_of_val(&self.null_equals_null);
1426
8.11k
        size += std::mem::size_of_val(&self.metrics);
1427
8.11k
        size
1428
8.11k
    }
1429
1430
    /// Performs a join operation for the specified `probe_side` (either left or right).
1431
    /// This function:
1432
    /// 1. Determines which side is the probe and which is the build side.
1433
    /// 2. Updates metrics based on the batch that was polled.
1434
    /// 3. Executes the join with the given `probe_batch`.
1435
    /// 4. Optionally computes anti-join results if all conditions are met.
1436
    /// 5. Combines the results and returns a combined batch or `None` if no batch was produced.
1437
8.11k
    fn perform_join_for_given_side(
1438
8.11k
        &mut self,
1439
8.11k
        probe_batch: RecordBatch,
1440
8.11k
        probe_side: JoinSide,
1441
8.11k
    ) -> Result<Option<RecordBatch>> {
1442
        let (
1443
8.11k
            probe_hash_joiner,
1444
8.11k
            build_hash_joiner,
1445
8.11k
            probe_side_sorted_filter_expr,
1446
8.11k
            build_side_sorted_filter_expr,
1447
8.11k
            probe_side_metrics,
1448
8.11k
        ) = if probe_side.eq(&JoinSide::Left) {
1449
3.94k
            (
1450
3.94k
                &mut self.left,
1451
3.94k
                &mut self.right,
1452
3.94k
                &mut self.left_sorted_filter_expr,
1453
3.94k
                &mut self.right_sorted_filter_expr,
1454
3.94k
                &mut self.metrics.left,
1455
3.94k
            )
1456
        } else {
1457
4.16k
            (
1458
4.16k
                &mut self.right,
1459
4.16k
                &mut self.left,
1460
4.16k
                &mut self.right_sorted_filter_expr,
1461
4.16k
                &mut self.left_sorted_filter_expr,
1462
4.16k
                &mut self.metrics.right,
1463
4.16k
            )
1464
        };
1465
        // Update the metrics for the stream that was polled:
1466
8.11k
        probe_side_metrics.input_batches.add(1);
1467
8.11k
        probe_side_metrics.input_rows.add(probe_batch.num_rows());
1468
8.11k
        // Update the internal state of the hash joiner for the build side:
1469
8.11k
        probe_hash_joiner.update_internal_state(&probe_batch, &self.random_state)
?0
;
1470
        // Join the two sides:
1471
8.11k
        let equal_result = join_with_probe_batch(
1472
8.11k
            build_hash_joiner,
1473
8.11k
            probe_hash_joiner,
1474
8.11k
            &self.schema,
1475
8.11k
            self.join_type,
1476
8.11k
            self.filter.as_ref(),
1477
8.11k
            &probe_batch,
1478
8.11k
            &self.column_indices,
1479
8.11k
            &self.random_state,
1480
8.11k
            self.null_equals_null,
1481
8.11k
        )
?0
;
1482
        // Increment the offset for the probe hash joiner:
1483
8.11k
        probe_hash_joiner.offset += probe_batch.num_rows();
1484
1485
8.11k
        let anti_result = if let (
1486
6.78k
            Some(build_side_sorted_filter_expr),
1487
6.78k
            Some(probe_side_sorted_filter_expr),
1488
6.78k
            Some(graph),
1489
        ) = (
1490
8.11k
            build_side_sorted_filter_expr.as_mut(),
1491
8.11k
            probe_side_sorted_filter_expr.as_mut(),
1492
8.11k
            self.graph.as_mut(),
1493
        ) {
1494
            // Calculate filter intervals:
1495
6.78k
            calculate_filter_expr_intervals(
1496
6.78k
                &build_hash_joiner.input_buffer,
1497
6.78k
                build_side_sorted_filter_expr,
1498
6.78k
                &probe_batch,
1499
6.78k
                probe_side_sorted_filter_expr,
1500
6.78k
            )
?0
;
1501
6.78k
            let prune_length = build_hash_joiner
1502
6.78k
                .calculate_prune_length_with_probe_batch(
1503
6.78k
                    build_side_sorted_filter_expr,
1504
6.78k
                    probe_side_sorted_filter_expr,
1505
6.78k
                    graph,
1506
6.78k
                )
?0
;
1507
6.78k
            let result = build_side_determined_results(
1508
6.78k
                build_hash_joiner,
1509
6.78k
                &self.schema,
1510
6.78k
                prune_length,
1511
6.78k
                probe_batch.schema(),
1512
6.78k
                self.join_type,
1513
6.78k
                &self.column_indices,
1514
6.78k
            )
?0
;
1515
6.78k
            build_hash_joiner.prune_internal_state(prune_length)
?0
;
1516
6.78k
            result
1517
        } else {
1518
1.33k
            None
1519
        };
1520
1521
        // Combine results:
1522
8.11k
        let result = combine_two_batches(&self.schema, equal_result, anti_result)
?0
;
1523
8.11k
        let capacity = self.size();
1524
8.11k
        self.metrics.stream_memory_usage.set(capacity);
1525
8.11k
        self.reservation.lock().try_resize(capacity)
?0
;
1526
        // Update the metrics if we have a batch; otherwise, continue the loop.
1527
8.11k
        if let Some(
batch2.10k
) = &result {
1528
2.10k
            self.metrics.output_batches.add(1);
1529
2.10k
            self.metrics.output_rows.add(batch.num_rows());
1530
6.01k
        }
1531
8.11k
        Ok(result)
1532
8.11k
    }
1533
}
1534
1535
/// The states of a symmetric hash join stream.
///
/// Tracks which input of the join should be pulled next, and whether one or
/// both inputs have been exhausted.
#[derive(Clone, Debug)]
pub enum SHJStreamState {
    /// The next step should pull a batch from the right side of the join.
    PullRight,

    /// The next step should pull a batch from the left side of the join.
    PullLeft,

    /// The right side of the join has been exhausted (fully processed).
    RightExhausted,

    /// The left side of the join has been exhausted (fully processed).
    LeftExhausted,

    /// Both sides of the join have been exhausted.
    ///
    /// The `final_result` field indicates whether the join operation has
    /// produced its final result.
    BothExhausted { final_result: bool },
}
1562
1563
#[cfg(test)]
1564
mod tests {
1565
    use std::collections::HashMap;
1566
    use std::sync::Mutex;
1567
1568
    use super::*;
1569
    use crate::joins::test_utils::{
1570
        build_sides_record_batches, compare_batches, complicated_filter,
1571
        create_memory_table, join_expr_tests_fixture_f64, join_expr_tests_fixture_i32,
1572
        join_expr_tests_fixture_temporal, partitioned_hash_join_with_filter,
1573
        partitioned_sym_join_with_filter, split_record_batches,
1574
    };
1575
1576
    use arrow::compute::SortOptions;
1577
    use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
1578
    use datafusion_common::ScalarValue;
1579
    use datafusion_execution::config::SessionConfig;
1580
    use datafusion_expr::Operator;
1581
    use datafusion_physical_expr::expressions::{binary, col, lit, Column};
1582
1583
    use once_cell::sync::Lazy;
1584
    use rstest::*;
1585
1586
    const TABLE_SIZE: i32 = 30;
1587
1588
    type TableKey = (i32, i32, usize); // (cardinality.0, cardinality.1, batch_size)
1589
    type TableValue = (Vec<RecordBatch>, Vec<RecordBatch>); // (left, right)
1590
1591
    // Cache for storing tables
1592
    static TABLE_CACHE: Lazy<Mutex<HashMap<TableKey, TableValue>>> =
1593
1
        Lazy::new(|| Mutex::new(HashMap::new()));
1594
1595
333
    fn get_or_create_table(
1596
333
        cardinality: (i32, i32),
1597
333
        batch_size: usize,
1598
333
    ) -> Result<TableValue> {
1599
333
        {
1600
333
            let cache = TABLE_CACHE.lock().unwrap();
1601
333
            if let Some(
table328
) = cache.get(&(cardinality.0, cardinality.1, batch_size)) {
1602
328
                return Ok(table.clone());
1603
5
            }
1604
        }
1605
1606
        // If not, create the table
1607
5
        let (left_batch, right_batch) =
1608
5
            build_sides_record_batches(TABLE_SIZE, cardinality)
?0
;
1609
1610
5
        let (left_partition, right_partition) = (
1611
5
            split_record_batches(&left_batch, batch_size)
?0
,
1612
5
            split_record_batches(&right_batch, batch_size)
?0
,
1613
        );
1614
1615
        // Lock the cache again and store the table
1616
5
        let mut cache = TABLE_CACHE.lock().unwrap();
1617
5
1618
5
        // Store the table in the cache
1619
5
        cache.insert(
1620
5
            (cardinality.0, cardinality.1, batch_size),
1621
5
            (left_partition.clone(), right_partition.clone()),
1622
5
        );
1623
5
1624
5
        Ok((left_partition, right_partition))
1625
333
    }
1626
1627
333
    pub async fn experiment(
1628
333
        left: Arc<dyn ExecutionPlan>,
1629
333
        right: Arc<dyn ExecutionPlan>,
1630
333
        filter: Option<JoinFilter>,
1631
333
        join_type: JoinType,
1632
333
        on: JoinOn,
1633
333
        task_ctx: Arc<TaskContext>,
1634
333
    ) -> Result<()> {
1635
333
        let first_batches = partitioned_sym_join_with_filter(
1636
333
            Arc::clone(&left),
1637
333
            Arc::clone(&right),
1638
333
            on.clone(),
1639
333
            filter.clone(),
1640
333
            &join_type,
1641
333
            false,
1642
333
            Arc::clone(&task_ctx),
1643
333
        )
1644
697
        .await
?0
;
1645
333
        let second_batches = partitioned_hash_join_with_filter(
1646
333
            left, right, on, filter, &join_type, false, task_ctx,
1647
333
        )
1648
2.28k
        .await
?0
;
1649
333
        compare_batches(&first_batches, &second_batches);
1650
333
        Ok(())
1651
333
    }
1652
1653
16
    #[rstest]
1654
    #[tokio::test(flavor = "multi_thread")]
1655
    async fn complex_join_all_one_ascending_numeric(
1656
        #[values(
1657
            JoinType::Inner,
1658
            JoinType::Left,
1659
            JoinType::Right,
1660
            JoinType::RightSemi,
1661
            JoinType::LeftSemi,
1662
            JoinType::LeftAnti,
1663
            JoinType::RightAnti,
1664
            JoinType::Full
1665
        )]
1666
        join_type: JoinType,
1667
        #[values(
1668
        (4, 5),
1669
        (12, 17),
1670
        )]
1671
        cardinality: (i32, i32),
1672
    ) -> Result<()> {
1673
        // a + b > c + 10 AND a + b < c + 100
1674
        let task_ctx = Arc::new(TaskContext::default());
1675
1676
        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
1677
1678
        let left_schema = &left_partition[0].schema();
1679
        let right_schema = &right_partition[0].schema();
1680
1681
        let left_sorted = vec![PhysicalSortExpr {
1682
            expr: binary(
1683
                col("la1", left_schema)?,
1684
                Operator::Plus,
1685
                col("la2", left_schema)?,
1686
                left_schema,
1687
            )?,
1688
            options: SortOptions::default(),
1689
        }];
1690
        let right_sorted = vec![PhysicalSortExpr {
1691
            expr: col("ra1", right_schema)?,
1692
            options: SortOptions::default(),
1693
        }];
1694
        let (left, right) = create_memory_table(
1695
            left_partition,
1696
            right_partition,
1697
            vec![left_sorted],
1698
            vec![right_sorted],
1699
        )?;
1700
1701
        let on = vec![(
1702
            binary(
1703
                col("lc1", left_schema)?,
1704
                Operator::Plus,
1705
                lit(ScalarValue::Int32(Some(1))),
1706
                left_schema,
1707
            )?,
1708
            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
1709
        )];
1710
1711
        let intermediate_schema = Schema::new(vec![
1712
            Field::new("0", DataType::Int32, true),
1713
            Field::new("1", DataType::Int32, true),
1714
            Field::new("2", DataType::Int32, true),
1715
        ]);
1716
        let filter_expr = complicated_filter(&intermediate_schema)?;
1717
        let column_indices = vec![
1718
            ColumnIndex {
1719
                index: 0,
1720
                side: JoinSide::Left,
1721
            },
1722
            ColumnIndex {
1723
                index: 4,
1724
                side: JoinSide::Left,
1725
            },
1726
            ColumnIndex {
1727
                index: 0,
1728
                side: JoinSide::Right,
1729
            },
1730
        ];
1731
        let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
1732
1733
        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
1734
        Ok(())
1735
    }
1736
1737
48
    #[rstest]
1738
    #[tokio::test(flavor = "multi_thread")]
1739
    async fn join_all_one_ascending_numeric(
1740
        #[values(
1741
            JoinType::Inner,
1742
            JoinType::Left,
1743
            JoinType::Right,
1744
            JoinType::RightSemi,
1745
            JoinType::LeftSemi,
1746
            JoinType::LeftAnti,
1747
            JoinType::RightAnti,
1748
            JoinType::Full
1749
        )]
1750
        join_type: JoinType,
1751
        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
1752
    ) -> Result<()> {
1753
        let task_ctx = Arc::new(TaskContext::default());
1754
        let (left_partition, right_partition) = get_or_create_table((4, 5), 8)?;
1755
1756
        let left_schema = &left_partition[0].schema();
1757
        let right_schema = &right_partition[0].schema();
1758
1759
        let left_sorted = vec![PhysicalSortExpr {
1760
            expr: col("la1", left_schema)?,
1761
            options: SortOptions::default(),
1762
        }];
1763
        let right_sorted = vec![PhysicalSortExpr {
1764
            expr: col("ra1", right_schema)?,
1765
            options: SortOptions::default(),
1766
        }];
1767
        let (left, right) = create_memory_table(
1768
            left_partition,
1769
            right_partition,
1770
            vec![left_sorted],
1771
            vec![right_sorted],
1772
        )?;
1773
1774
        let on = vec![(
1775
            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
1776
            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
1777
        )];
1778
1779
        let intermediate_schema = Schema::new(vec![
1780
            Field::new("left", DataType::Int32, true),
1781
            Field::new("right", DataType::Int32, true),
1782
        ]);
1783
        let filter_expr = join_expr_tests_fixture_i32(
1784
            case_expr,
1785
            col("left", &intermediate_schema)?,
1786
            col("right", &intermediate_schema)?,
1787
        );
1788
        let column_indices = vec![
1789
            ColumnIndex {
1790
                index: 0,
1791
                side: JoinSide::Left,
1792
            },
1793
            ColumnIndex {
1794
                index: 0,
1795
                side: JoinSide::Right,
1796
            },
1797
        ];
1798
        let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
1799
1800
        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
1801
        Ok(())
1802
    }
1803
1804
48
    #[rstest]
1805
    #[tokio::test(flavor = "multi_thread")]
1806
    async fn join_without_sort_information(
1807
        #[values(
1808
            JoinType::Inner,
1809
            JoinType::Left,
1810
            JoinType::Right,
1811
            JoinType::RightSemi,
1812
            JoinType::LeftSemi,
1813
            JoinType::LeftAnti,
1814
            JoinType::RightAnti,
1815
            JoinType::Full
1816
        )]
1817
        join_type: JoinType,
1818
        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
1819
    ) -> Result<()> {
1820
        let task_ctx = Arc::new(TaskContext::default());
1821
        let (left_partition, right_partition) = get_or_create_table((4, 5), 8)?;
1822
1823
        let left_schema = &left_partition[0].schema();
1824
        let right_schema = &right_partition[0].schema();
1825
        let (left, right) =
1826
            create_memory_table(left_partition, right_partition, vec![], vec![])?;
1827
1828
        let on = vec![(
1829
            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
1830
            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
1831
        )];
1832
1833
        let intermediate_schema = Schema::new(vec![
1834
            Field::new("left", DataType::Int32, true),
1835
            Field::new("right", DataType::Int32, true),
1836
        ]);
1837
        let filter_expr = join_expr_tests_fixture_i32(
1838
            case_expr,
1839
            col("left", &intermediate_schema)?,
1840
            col("right", &intermediate_schema)?,
1841
        );
1842
        let column_indices = vec![
1843
            ColumnIndex {
1844
                index: 5,
1845
                side: JoinSide::Left,
1846
            },
1847
            ColumnIndex {
1848
                index: 5,
1849
                side: JoinSide::Right,
1850
            },
1851
        ];
1852
        let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
1853
1854
        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
1855
        Ok(())
1856
    }
1857
1858
8
    #[rstest]
1859
    #[tokio::test(flavor = "multi_thread")]
1860
    async fn join_without_filter(
1861
        #[values(
1862
            JoinType::Inner,
1863
            JoinType::Left,
1864
            JoinType::Right,
1865
            JoinType::RightSemi,
1866
            JoinType::LeftSemi,
1867
            JoinType::LeftAnti,
1868
            JoinType::RightAnti,
1869
            JoinType::Full
1870
        )]
1871
        join_type: JoinType,
1872
    ) -> Result<()> {
1873
        let task_ctx = Arc::new(TaskContext::default());
1874
        let (left_partition, right_partition) = get_or_create_table((11, 21), 8)?;
1875
        let left_schema = &left_partition[0].schema();
1876
        let right_schema = &right_partition[0].schema();
1877
        let (left, right) =
1878
            create_memory_table(left_partition, right_partition, vec![], vec![])?;
1879
1880
        let on = vec![(
1881
            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
1882
            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
1883
        )];
1884
        experiment(left, right, None, join_type, on, task_ctx).await?;
1885
        Ok(())
1886
    }
1887
1888
48
    #[rstest]
1889
    #[tokio::test(flavor = "multi_thread")]
1890
    async fn join_all_one_descending_numeric_particular(
1891
        #[values(
1892
            JoinType::Inner,
1893
            JoinType::Left,
1894
            JoinType::Right,
1895
            JoinType::RightSemi,
1896
            JoinType::LeftSemi,
1897
            JoinType::LeftAnti,
1898
            JoinType::RightAnti,
1899
            JoinType::Full
1900
        )]
1901
        join_type: JoinType,
1902
        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
1903
    ) -> Result<()> {
1904
        let task_ctx = Arc::new(TaskContext::default());
1905
        let (left_partition, right_partition) = get_or_create_table((11, 21), 8)?;
1906
        let left_schema = &left_partition[0].schema();
1907
        let right_schema = &right_partition[0].schema();
1908
        let left_sorted = vec![PhysicalSortExpr {
1909
            expr: col("la1_des", left_schema)?,
1910
            options: SortOptions {
1911
                descending: true,
1912
                nulls_first: true,
1913
            },
1914
        }];
1915
        let right_sorted = vec![PhysicalSortExpr {
1916
            expr: col("ra1_des", right_schema)?,
1917
            options: SortOptions {
1918
                descending: true,
1919
                nulls_first: true,
1920
            },
1921
        }];
1922
        let (left, right) = create_memory_table(
1923
            left_partition,
1924
            right_partition,
1925
            vec![left_sorted],
1926
            vec![right_sorted],
1927
        )?;
1928
1929
        let on = vec![(
1930
            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
1931
            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
1932
        )];
1933
1934
        let intermediate_schema = Schema::new(vec![
1935
            Field::new("left", DataType::Int32, true),
1936
            Field::new("right", DataType::Int32, true),
1937
        ]);
1938
        let filter_expr = join_expr_tests_fixture_i32(
1939
            case_expr,
1940
            col("left", &intermediate_schema)?,
1941
            col("right", &intermediate_schema)?,
1942
        );
1943
        let column_indices = vec![
1944
            ColumnIndex {
1945
                index: 5,
1946
                side: JoinSide::Left,
1947
            },
1948
            ColumnIndex {
1949
                index: 5,
1950
                side: JoinSide::Right,
1951
            },
1952
        ];
1953
        let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
1954
1955
        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
1956
        Ok(())
1957
    }
1958
1959
    #[tokio::test(flavor = "multi_thread")]
1960
1
    async fn build_null_columns_first() -> Result<()> {
1961
1
        let join_type = JoinType::Full;
1962
1
        let case_expr = 1;
1963
1
        let session_config = SessionConfig::new().with_repartition_joins(false);
1964
1
        let task_ctx = TaskContext::default().with_session_config(session_config);
1965
1
        let task_ctx = Arc::new(task_ctx);
1966
1
        let (left_partition, right_partition) = get_or_create_table((10, 11), 8)
?0
;
1967
1
        let left_schema = &left_partition[0].schema();
1968
1
        let right_schema = &right_partition[0].schema();
1969
1
        let left_sorted = vec![PhysicalSortExpr {
1970
1
            expr: col("l_asc_null_first", left_schema)
?0
,
1971
1
            options: SortOptions {
1972
1
                descending: false,
1973
1
                nulls_first: true,
1974
1
            },
1975
1
        }];
1976
1
        let right_sorted = vec![PhysicalSortExpr {
1977
1
            expr: col("r_asc_null_first", right_schema)
?0
,
1978
1
            options: SortOptions {
1979
1
                descending: false,
1980
1
                nulls_first: true,
1981
1
            },
1982
1
        }];
1983
1
        let (left, right) = create_memory_table(
1984
1
            left_partition,
1985
1
            right_partition,
1986
1
            vec![left_sorted],
1987
1
            vec![right_sorted],
1988
1
        )
?0
;
1989
1
1990
1
        let on = vec![(
1991
1
            Arc::new(Column::new_with_schema("lc1", left_schema)
?0
) as _,
1992
1
            Arc::new(Column::new_with_schema("rc1", right_schema)
?0
) as _,
1993
1
        )];
1994
1
1995
1
        let intermediate_schema = Schema::new(vec![
1996
1
            Field::new("left", DataType::Int32, true),
1997
1
            Field::new("right", DataType::Int32, true),
1998
1
        ]);
1999
1
        let filter_expr = join_expr_tests_fixture_i32(
2000
1
            case_expr,
2001
1
            col("left", &intermediate_schema)
?0
,
2002
1
            col("right", &intermediate_schema)
?0
,
2003
1
        );
2004
1
        let column_indices = vec![
2005
1
            ColumnIndex {
2006
1
                index: 6,
2007
1
                side: JoinSide::Left,
2008
1
            },
2009
1
            ColumnIndex {
2010
1
                index: 6,
2011
1
                side: JoinSide::Right,
2012
1
            },
2013
1
        ];
2014
1
        let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
2015
12
        experiment(left, right, Some(filter), join_type, on, task_ctx).await
?0
;
2016
1
        Ok(())
2017
1
    }
2018
2019
    #[tokio::test(flavor = "multi_thread")]
2020
1
    async fn build_null_columns_last() -> Result<()> {
2021
1
        let join_type = JoinType::Full;
2022
1
        let case_expr = 1;
2023
1
        let session_config = SessionConfig::new().with_repartition_joins(false);
2024
1
        let task_ctx = TaskContext::default().with_session_config(session_config);
2025
1
        let task_ctx = Arc::new(task_ctx);
2026
1
        let (left_partition, right_partition) = get_or_create_table((10, 11), 8)
?0
;
2027
1
2028
1
        let left_schema = &left_partition[0].schema();
2029
1
        let right_schema = &right_partition[0].schema();
2030
1
        let left_sorted = vec![PhysicalSortExpr {
2031
1
            expr: col("l_asc_null_last", left_schema)
?0
,
2032
1
            options: SortOptions {
2033
1
                descending: false,
2034
1
                nulls_first: false,
2035
1
            },
2036
1
        }];
2037
1
        let right_sorted = vec![PhysicalSortExpr {
2038
1
            expr: col("r_asc_null_last", right_schema)
?0
,
2039
1
            options: SortOptions {
2040
1
                descending: false,
2041
1
                nulls_first: false,
2042
1
            },
2043
1
        }];
2044
1
        let (left, right) = create_memory_table(
2045
1
            left_partition,
2046
1
            right_partition,
2047
1
            vec![left_sorted],
2048
1
            vec![right_sorted],
2049
1
        )
?0
;
2050
1
2051
1
        let on = vec![(
2052
1
            Arc::new(Column::new_with_schema("lc1", left_schema)
?0
) as _,
2053
1
            Arc::new(Column::new_with_schema("rc1", right_schema)
?0
) as _,
2054
1
        )];
2055
1
2056
1
        let intermediate_schema = Schema::new(vec![
2057
1
            Field::new("left", DataType::Int32, true),
2058
1
            Field::new("right", DataType::Int32, true),
2059
1
        ]);
2060
1
        let filter_expr = join_expr_tests_fixture_i32(
2061
1
            case_expr,
2062
1
            col("left", &intermediate_schema)
?0
,
2063
1
            col("right", &intermediate_schema)
?0
,
2064
1
        );
2065
1
        let column_indices = vec![
2066
1
            ColumnIndex {
2067
1
                index: 7,
2068
1
                side: JoinSide::Left,
2069
1
            },
2070
1
            ColumnIndex {
2071
1
                index: 7,
2072
1
                side: JoinSide::Right,
2073
1
            },
2074
1
        ];
2075
1
        let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
2076
1
2077
12
        experiment(left, right, Some(filter), join_type, on, task_ctx).await
?0
;
2078
1
        Ok(())
2079
1
    }
2080
2081
    #[tokio::test(flavor = "multi_thread")]
2082
1
    async fn build_null_columns_first_descending() -> Result<()> {
2083
1
        let join_type = JoinType::Full;
2084
1
        let cardinality = (10, 11);
2085
1
        let case_expr = 1;
2086
1
        let session_config = SessionConfig::new().with_repartition_joins(false);
2087
1
        let task_ctx = TaskContext::default().with_session_config(session_config);
2088
1
        let task_ctx = Arc::new(task_ctx);
2089
1
        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)
?0
;
2090
1
2091
1
        let left_schema = &left_partition[0].schema();
2092
1
        let right_schema = &right_partition[0].schema();
2093
1
        let left_sorted = vec![PhysicalSortExpr {
2094
1
            expr: col("l_desc_null_first", left_schema)
?0
,
2095
1
            options: SortOptions {
2096
1
                descending: true,
2097
1
                nulls_first: true,
2098
1
            },
2099
1
        }];
2100
1
        let right_sorted = vec![PhysicalSortExpr {
2101
1
            expr: col("r_desc_null_first", right_schema)
?0
,
2102
1
            options: SortOptions {
2103
1
                descending: true,
2104
1
                nulls_first: true,
2105
1
            },
2106
1
        }];
2107
1
        let (left, right) = create_memory_table(
2108
1
            left_partition,
2109
1
            right_partition,
2110
1
            vec![left_sorted],
2111
1
            vec![right_sorted],
2112
1
        )
?0
;
2113
1
2114
1
        let on = vec![(
2115
1
            Arc::new(Column::new_with_schema("lc1", left_schema)
?0
) as _,
2116
1
            Arc::new(Column::new_with_schema("rc1", right_schema)
?0
) as _,
2117
1
        )];
2118
1
2119
1
        let intermediate_schema = Schema::new(vec![
2120
1
            Field::new("left", DataType::Int32, true),
2121
1
            Field::new("right", DataType::Int32, true),
2122
1
        ]);
2123
1
        let filter_expr = join_expr_tests_fixture_i32(
2124
1
            case_expr,
2125
1
            col("left", &intermediate_schema)
?0
,
2126
1
            col("right", &intermediate_schema)
?0
,
2127
1
        );
2128
1
        let column_indices = vec![
2129
1
            ColumnIndex {
2130
1
                index: 8,
2131
1
                side: JoinSide::Left,
2132
1
            },
2133
1
            ColumnIndex {
2134
1
                index: 8,
2135
1
                side: JoinSide::Right,
2136
1
            },
2137
1
        ];
2138
1
        let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
2139
1
2140
12
        experiment(left, right, Some(filter), join_type, on, task_ctx).await
?0
;
2141
1
        Ok(())
2142
1
    }
2143
2144
    #[tokio::test(flavor = "multi_thread")]
2145
1
    async fn complex_join_all_one_ascending_numeric_missing_stat() -> Result<()> {
2146
1
        let cardinality = (3, 4);
2147
1
        let join_type = JoinType::Full;
2148
1
2149
1
        // a + b > c + 10 AND a + b < c + 100
2150
1
        let session_config = SessionConfig::new().with_repartition_joins(false);
2151
1
        let task_ctx = TaskContext::default().with_session_config(session_config);
2152
1
        let task_ctx = Arc::new(task_ctx);
2153
1
        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)
?0
;
2154
1
2155
1
        let left_schema = &left_partition[0].schema();
2156
1
        let right_schema = &right_partition[0].schema();
2157
1
        let left_sorted = vec![PhysicalSortExpr {
2158
1
            expr: col("la1", left_schema)
?0
,
2159
1
            options: SortOptions::default(),
2160
1
        }];
2161
1
2162
1
        let right_sorted = vec![PhysicalSortExpr {
2163
1
            expr: col("ra1", right_schema)
?0
,
2164
1
            options: SortOptions::default(),
2165
1
        }];
2166
1
        let (left, right) = create_memory_table(
2167
1
            left_partition,
2168
1
            right_partition,
2169
1
            vec![left_sorted],
2170
1
            vec![right_sorted],
2171
1
        )
?0
;
2172
1
2173
1
        let on = vec![(
2174
1
            Arc::new(Column::new_with_schema("lc1", left_schema)
?0
) as _,
2175
1
            Arc::new(Column::new_with_schema("rc1", right_schema)
?0
) as _,
2176
1
        )];
2177
1
2178
1
        let intermediate_schema = Schema::new(vec![
2179
1
            Field::new("0", DataType::Int32, true),
2180
1
            Field::new("1", DataType::Int32, true),
2181
1
            Field::new("2", DataType::Int32, true),
2182
1
        ]);
2183
1
        let filter_expr = complicated_filter(&intermediate_schema)
?0
;
2184
1
        let column_indices = vec![
2185
1
            ColumnIndex {
2186
1
                index: 0,
2187
1
                side: JoinSide::Left,
2188
1
            },
2189
1
            ColumnIndex {
2190
1
                index: 4,
2191
1
                side: JoinSide::Left,
2192
1
            },
2193
1
            ColumnIndex {
2194
1
                index: 0,
2195
1
                side: JoinSide::Right,
2196
1
            },
2197
1
        ];
2198
1
        let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
2199
1
2200
8
        experiment(left, right, Some(filter), join_type, on, task_ctx).await
?0
;
2201
1
        Ok(())
2202
1
    }
2203
2204
    #[tokio::test(flavor = "multi_thread")]
2205
1
    async fn complex_join_all_one_ascending_equivalence() -> Result<()> {
2206
1
        let cardinality = (3, 4);
2207
1
        let join_type = JoinType::Full;
2208
1
2209
1
        // a + b > c + 10 AND a + b < c + 100
2210
1
        let config = SessionConfig::new().with_repartition_joins(false);
2211
1
        // let session_ctx = SessionContext::with_config(config);
2212
1
        // let task_ctx = session_ctx.task_ctx();
2213
1
        let task_ctx = Arc::new(TaskContext::default().with_session_config(config));
2214
1
        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)
?0
;
2215
1
        let left_schema = &left_partition[0].schema();
2216
1
        let right_schema = &right_partition[0].schema();
2217
1
        let left_sorted = vec![
2218
1
            vec![PhysicalSortExpr {
2219
1
                expr: col("la1", left_schema)
?0
,
2220
1
                options: SortOptions::default(),
2221
1
            }],
2222
1
            vec![PhysicalSortExpr {
2223
1
                expr: col("la2", left_schema)
?0
,
2224
1
                options: SortOptions::default(),
2225
1
            }],
2226
1
        ];
2227
1
2228
1
        let right_sorted = vec![PhysicalSortExpr {
2229
1
            expr: col("ra1", right_schema)
?0
,
2230
1
            options: SortOptions::default(),
2231
1
        }];
2232
1
2233
1
        let (left, right) = create_memory_table(
2234
1
            left_partition,
2235
1
            right_partition,
2236
1
            left_sorted,
2237
1
            vec![right_sorted],
2238
1
        )
?0
;
2239
1
2240
1
        let on = vec![(
2241
1
            Arc::new(Column::new_with_schema("lc1", left_schema)
?0
) as _,
2242
1
            Arc::new(Column::new_with_schema("rc1", right_schema)
?0
) as _,
2243
1
        )];
2244
1
2245
1
        let intermediate_schema = Schema::new(vec![
2246
1
            Field::new("0", DataType::Int32, true),
2247
1
            Field::new("1", DataType::Int32, true),
2248
1
            Field::new("2", DataType::Int32, true),
2249
1
        ]);
2250
1
        let filter_expr = complicated_filter(&intermediate_schema)
?0
;
2251
1
        let column_indices = vec![
2252
1
            ColumnIndex {
2253
1
                index: 0,
2254
1
                side: JoinSide::Left,
2255
1
            },
2256
1
            ColumnIndex {
2257
1
                index: 4,
2258
1
                side: JoinSide::Left,
2259
1
            },
2260
1
            ColumnIndex {
2261
1
                index: 0,
2262
1
                side: JoinSide::Right,
2263
1
            },
2264
1
        ];
2265
1
        let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
2266
1
2267
10
        experiment(left, right, Some(filter), join_type, on, task_ctx).await
?0
;
2268
1
        Ok(())
2269
1
    }
2270
2271
48
    #[rstest]
2272
    #[tokio::test(flavor = "multi_thread")]
2273
    async fn testing_with_temporal_columns(
2274
        #[values(
2275
            JoinType::Inner,
2276
            JoinType::Left,
2277
            JoinType::Right,
2278
            JoinType::RightSemi,
2279
            JoinType::LeftSemi,
2280
            JoinType::LeftAnti,
2281
            JoinType::RightAnti,
2282
            JoinType::Full
2283
        )]
2284
        join_type: JoinType,
2285
        #[values(
2286
            (4, 5),
2287
            (12, 17),
2288
        )]
2289
        cardinality: (i32, i32),
2290
        #[values(0, 1, 2)] case_expr: usize,
2291
    ) -> Result<()> {
2292
        let session_config = SessionConfig::new().with_repartition_joins(false);
2293
        let task_ctx = TaskContext::default().with_session_config(session_config);
2294
        let task_ctx = Arc::new(task_ctx);
2295
        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2296
2297
        let left_schema = &left_partition[0].schema();
2298
        let right_schema = &right_partition[0].schema();
2299
        let on = vec![(
2300
            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
2301
            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
2302
        )];
2303
        let left_sorted = vec![PhysicalSortExpr {
2304
            expr: col("lt1", left_schema)?,
2305
            options: SortOptions {
2306
                descending: false,
2307
                nulls_first: true,
2308
            },
2309
        }];
2310
        let right_sorted = vec![PhysicalSortExpr {
2311
            expr: col("rt1", right_schema)?,
2312
            options: SortOptions {
2313
                descending: false,
2314
                nulls_first: true,
2315
            },
2316
        }];
2317
        let (left, right) = create_memory_table(
2318
            left_partition,
2319
            right_partition,
2320
            vec![left_sorted],
2321
            vec![right_sorted],
2322
        )?;
2323
        let intermediate_schema = Schema::new(vec![
2324
            Field::new(
2325
                "left",
2326
                DataType::Timestamp(TimeUnit::Millisecond, None),
2327
                false,
2328
            ),
2329
            Field::new(
2330
                "right",
2331
                DataType::Timestamp(TimeUnit::Millisecond, None),
2332
                false,
2333
            ),
2334
        ]);
2335
        let filter_expr = join_expr_tests_fixture_temporal(
2336
            case_expr,
2337
            col("left", &intermediate_schema)?,
2338
            col("right", &intermediate_schema)?,
2339
            &intermediate_schema,
2340
        )?;
2341
        let column_indices = vec![
2342
            ColumnIndex {
2343
                index: 3,
2344
                side: JoinSide::Left,
2345
            },
2346
            ColumnIndex {
2347
                index: 3,
2348
                side: JoinSide::Right,
2349
            },
2350
        ];
2351
        let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
2352
        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2353
        Ok(())
2354
    }
2355
2356
16
    #[rstest]
2357
    #[tokio::test(flavor = "multi_thread")]
2358
    async fn test_with_interval_columns(
2359
        #[values(
2360
            JoinType::Inner,
2361
            JoinType::Left,
2362
            JoinType::Right,
2363
            JoinType::RightSemi,
2364
            JoinType::LeftSemi,
2365
            JoinType::LeftAnti,
2366
            JoinType::RightAnti,
2367
            JoinType::Full
2368
        )]
2369
        join_type: JoinType,
2370
        #[values(
2371
            (4, 5),
2372
            (12, 17),
2373
        )]
2374
        cardinality: (i32, i32),
2375
    ) -> Result<()> {
2376
        let session_config = SessionConfig::new().with_repartition_joins(false);
2377
        let task_ctx = TaskContext::default().with_session_config(session_config);
2378
        let task_ctx = Arc::new(task_ctx);
2379
        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2380
2381
        let left_schema = &left_partition[0].schema();
2382
        let right_schema = &right_partition[0].schema();
2383
        let on = vec![(
2384
            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
2385
            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
2386
        )];
2387
        let left_sorted = vec![PhysicalSortExpr {
2388
            expr: col("li1", left_schema)?,
2389
            options: SortOptions {
2390
                descending: false,
2391
                nulls_first: true,
2392
            },
2393
        }];
2394
        let right_sorted = vec![PhysicalSortExpr {
2395
            expr: col("ri1", right_schema)?,
2396
            options: SortOptions {
2397
                descending: false,
2398
                nulls_first: true,
2399
            },
2400
        }];
2401
        let (left, right) = create_memory_table(
2402
            left_partition,
2403
            right_partition,
2404
            vec![left_sorted],
2405
            vec![right_sorted],
2406
        )?;
2407
        let intermediate_schema = Schema::new(vec![
2408
            Field::new("left", DataType::Interval(IntervalUnit::DayTime), false),
2409
            Field::new("right", DataType::Interval(IntervalUnit::DayTime), false),
2410
        ]);
2411
        let filter_expr = join_expr_tests_fixture_temporal(
2412
            0,
2413
            col("left", &intermediate_schema)?,
2414
            col("right", &intermediate_schema)?,
2415
            &intermediate_schema,
2416
        )?;
2417
        let column_indices = vec![
2418
            ColumnIndex {
2419
                index: 9,
2420
                side: JoinSide::Left,
2421
            },
2422
            ColumnIndex {
2423
                index: 9,
2424
                side: JoinSide::Right,
2425
            },
2426
        ];
2427
        let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
2428
        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2429
2430
        Ok(())
2431
    }
2432
2433
96
    #[rstest]
2434
    #[tokio::test(flavor = "multi_thread")]
2435
    async fn testing_ascending_float_pruning(
2436
        #[values(
2437
            JoinType::Inner,
2438
            JoinType::Left,
2439
            JoinType::Right,
2440
            JoinType::RightSemi,
2441
            JoinType::LeftSemi,
2442
            JoinType::LeftAnti,
2443
            JoinType::RightAnti,
2444
            JoinType::Full
2445
        )]
2446
        join_type: JoinType,
2447
        #[values(
2448
            (4, 5),
2449
            (12, 17),
2450
        )]
2451
        cardinality: (i32, i32),
2452
        #[values(0, 1, 2, 3, 4, 5)] case_expr: usize,
2453
    ) -> Result<()> {
2454
        let session_config = SessionConfig::new().with_repartition_joins(false);
2455
        let task_ctx = TaskContext::default().with_session_config(session_config);
2456
        let task_ctx = Arc::new(task_ctx);
2457
        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
2458
2459
        let left_schema = &left_partition[0].schema();
2460
        let right_schema = &right_partition[0].schema();
2461
        let left_sorted = vec![PhysicalSortExpr {
2462
            expr: col("l_float", left_schema)?,
2463
            options: SortOptions::default(),
2464
        }];
2465
        let right_sorted = vec![PhysicalSortExpr {
2466
            expr: col("r_float", right_schema)?,
2467
            options: SortOptions::default(),
2468
        }];
2469
        let (left, right) = create_memory_table(
2470
            left_partition,
2471
            right_partition,
2472
            vec![left_sorted],
2473
            vec![right_sorted],
2474
        )?;
2475
2476
        let on = vec![(
2477
            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
2478
            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
2479
        )];
2480
2481
        let intermediate_schema = Schema::new(vec![
2482
            Field::new("left", DataType::Float64, true),
2483
            Field::new("right", DataType::Float64, true),
2484
        ]);
2485
        let filter_expr = join_expr_tests_fixture_f64(
2486
            case_expr,
2487
            col("left", &intermediate_schema)?,
2488
            col("right", &intermediate_schema)?,
2489
        );
2490
        let column_indices = vec![
2491
            ColumnIndex {
2492
                index: 10, // l_float
2493
                side: JoinSide::Left,
2494
            },
2495
            ColumnIndex {
2496
                index: 10, // r_float
2497
                side: JoinSide::Right,
2498
            },
2499
        ];
2500
        let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
2501
2502
        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
2503
        Ok(())
2504
    }
2505
}