Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/joins/utils.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Join related functionality used both on logical and physical plans
19
20
use std::collections::HashSet;
21
use std::fmt::{self, Debug};
22
use std::future::Future;
23
use std::ops::{IndexMut, Range};
24
use std::sync::Arc;
25
use std::task::{Context, Poll};
26
27
use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder};
28
use crate::{
29
    ColumnStatistics, ExecutionPlan, ExecutionPlanProperties, Partitioning, Statistics,
30
};
31
32
use arrow::array::{
33
    downcast_array, new_null_array, Array, BooleanBufferBuilder, UInt32Array,
34
    UInt32Builder, UInt64Array,
35
};
36
use arrow::compute;
37
use arrow::datatypes::{Field, Schema, SchemaBuilder, UInt32Type, UInt64Type};
38
use arrow::record_batch::{RecordBatch, RecordBatchOptions};
39
use arrow_array::builder::UInt64Builder;
40
use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray};
41
use arrow_buffer::ArrowNativeType;
42
use datafusion_common::cast::as_boolean_array;
43
use datafusion_common::stats::Precision;
44
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
45
use datafusion_common::{
46
    plan_err, DataFusionError, JoinSide, JoinType, Result, SharedResult,
47
};
48
use datafusion_expr::interval_arithmetic::Interval;
49
use datafusion_physical_expr::equivalence::add_offset_to_expr;
50
use datafusion_physical_expr::expressions::Column;
51
use datafusion_physical_expr::utils::{collect_columns, merge_vectors};
52
use datafusion_physical_expr::{
53
    LexOrdering, LexOrderingRef, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr,
54
};
55
56
use futures::future::{BoxFuture, Shared};
57
use futures::{ready, FutureExt};
58
use hashbrown::raw::RawTable;
59
use parking_lot::Mutex;
60
61
/// Maps a `u64` hash value based on the build side ["on" values] to a list of indices with this key's value.
62
///
63
/// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the build side,
64
/// we make sure that we don't have to re-hash the hashmap, which needs access to the key (the hash in this case) value.
65
///
66
/// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1
67
/// As the key is a hash value, we need to check possible hash collisions in the probe stage
68
/// During this stage it might be the case that a row is contained in the same hashmap value,
69
/// but the values don't match. Those are checked in the `equal_rows_arr` method.
70
///
71
/// The indices (values) are stored in a separate chained list stored in the `Vec<u64>`.
72
///
73
/// The first value (+1) is stored in the hashmap, whereas the next value is stored in array at the position value.
74
///
75
/// The chain can be followed until the value "0" has been reached, meaning the end of the list.
76
/// Also see chapter 5.3 of [Balancing vectorized query execution with bandwidth-optimized storage](https://dare.uva.nl/search?identifier=5ccbb60a-38b8-4eeb-858a-e7735dd37487)
77
///
78
/// # Example
79
///
80
/// ``` text
81
/// See the example below:
82
///
83
/// Insert (10,1)            <-- insert hash value 10 with row index 1
84
/// map:
85
/// ----------
86
/// | 10 | 2 |
87
/// ----------
88
/// next:
89
/// ---------------------
90
/// | 0 | 0 | 0 | 0 | 0 |
91
/// ---------------------
92
/// Insert (20,2)
93
/// map:
94
/// ----------
95
/// | 10 | 2 |
96
/// | 20 | 3 |
97
/// ----------
98
/// next:
99
/// ---------------------
100
/// | 0 | 0 | 0 | 0 | 0 |
101
/// ---------------------
102
/// Insert (10,3)           <-- collision! row index 3 has a hash value of 10 as well
103
/// map:
104
/// ----------
105
/// | 10 | 4 |
106
/// | 20 | 3 |
107
/// ----------
108
/// next:
109
/// ---------------------
110
/// | 0 | 0 | 0 | 2 | 0 |  <--- hash value 10 maps to 4,2 (which means indices values 3,1)
111
/// ---------------------
112
/// Insert (10,4)          <-- another collision! row index 4 ALSO has a hash value of 10
113
/// map:
114
/// ---------
115
/// | 10 | 5 |
116
/// | 20 | 3 |
117
/// ---------
118
/// next:
119
/// ---------------------
120
/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 10 maps to 5,4,2 (which means indices values 4,3,1)
121
/// ---------------------
122
/// ```
123
pub struct JoinHashMap {
124
    // Stores hash value to last row index
125
    map: RawTable<(u64, u64)>,
126
    // Stores indices in chained list data structure
127
    next: Vec<u64>,
128
}
129
130
impl JoinHashMap {
131
    #[cfg(test)]
132
1
    pub(crate) fn new(map: RawTable<(u64, u64)>, next: Vec<u64>) -> Self {
133
1
        Self { map, next }
134
1
    }
135
136
1.71k
    pub(crate) fn with_capacity(capacity: usize) -> Self {
137
1.71k
        JoinHashMap {
138
1.71k
            map: RawTable::with_capacity(capacity),
139
1.71k
            next: vec![0; capacity],
140
1.71k
        }
141
1.71k
    }
142
}
143
144
/// Type of offsets for obtaining indices from JoinHashMap.
///
/// The first element is an index into the probe-side hash values. The second
/// element is the position inside that row's chain:
/// * `None`: the row has not been processed yet,
/// * `Some(0)`: the row's chain has been fully consumed,
/// * `Some(n)`: resume the row's chain at the (+1 encoded) chain value `n`.
pub(crate) type JoinHashMapOffset = (usize, Option<u64>);
146
147
// Macro for traversing chained values with limit.
//
// Walks the chain that starts at the (+1 encoded) value `$chain_idx`, pushing
// one `(input index, matched index)` pair per chain entry. The macro expands
// to statements inside the calling function and issues an early `return` from
// that function as soon as `$remaining_output` reaches zero.
macro_rules! chain_traverse {
    (
        $input_indices:ident, $match_indices:ident, $hash_values:ident, $next_chain:ident,
        $input_idx:ident, $chain_idx:ident, $deleted_offset:ident, $remaining_output:ident
    ) => {
        // Chain values are stored off by one; undo that to get the raw index.
        let mut cursor = $chain_idx - 1;
        loop {
            let build_row = match $deleted_offset {
                // Entries below the deleted offset were pruned earlier, so
                // the chain effectively ends here.
                Some(pruned) if cursor < pruned as u64 => break,
                Some(pruned) => cursor - pruned as u64,
                None => cursor,
            };
            $match_indices.push(build_row);
            $input_indices.push($input_idx as u32);
            $remaining_output -= 1;
            // Look up the successor before a potential early return so the
            // caller knows where to resume.
            let successor = $next_chain[build_row as usize];

            if $remaining_output == 0 {
                // When the current input index is the last one and its chain
                // is exhausted, the whole input has been scanned: report
                // `None` instead of a resume point.
                let is_last_input = $input_idx == $hash_values.len() - 1;
                let resume_from = if is_last_input && successor == 0 {
                    None
                } else {
                    Some(($input_idx, Some(successor)))
                };
                return ($input_indices, $match_indices, resume_from);
            }
            if successor == 0 {
                // End of the chain.
                break;
            }
            cursor = successor - 1;
        }
    };
}
190
191
// Trait defining methods that must be implemented by a hash map type to be used for joins.
192
pub trait JoinHashMapType {
193
    /// The type of list used to store the next list
194
    type NextType: IndexMut<usize, Output = u64>;
195
    /// Extend with zero
196
    fn extend_zero(&mut self, len: usize);
197
    /// Returns mutable references to the hash map and the next.
198
    fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType);
199
    /// Returns a reference to the hash map.
200
    fn get_map(&self) -> &RawTable<(u64, u64)>;
201
    /// Returns a reference to the next.
202
    fn get_list(&self) -> &Self::NextType;
203
204
    /// Updates hashmap from iterator of row indices & row hashes pairs.
205
12.4k
    fn update_from_iter<'a>(
206
12.4k
        &mut self,
207
12.4k
        iter: impl Iterator<Item = (usize, &'a u64)>,
208
12.4k
        deleted_offset: usize,
209
12.4k
    ) {
210
12.4k
        let (mut_map, mut_list) = self.get_mut();
211
43.8k
        for (
row, hash_value31.3k
) in iter {
212
31.3k
            let item = mut_map.get_mut(*hash_value, |(hash, _)| 
*hash_value == *hash19.0k
);
213
31.3k
            if let Some((_, 
index18.9k
)) = item {
214
18.9k
                // Already exists: add index to next array
215
18.9k
                let prev_index = *index;
216
18.9k
                // Store new value inside hashmap
217
18.9k
                *index = (row + 1) as u64;
218
18.9k
                // Update chained Vec at `row` with previous value
219
18.9k
                mut_list[row - deleted_offset] = prev_index;
220
18.9k
            } else {
221
12.4k
                mut_map.insert(
222
12.4k
                    *hash_value,
223
12.4k
                    // store the value + 1 as 0 value reserved for end of list
224
12.4k
                    (*hash_value, (row + 1) as u64),
225
12.4k
                    |(hash, _)| 
*hash1.83k
,
226
12.4k
                );
227
12.4k
                // chained list at `row` is already initialized with 0
228
12.4k
                // meaning end of list
229
12.4k
            }
230
        }
231
12.4k
    }
232
233
    /// Returns all pairs of row indices matched by hash.
234
    ///
235
    /// This method only compares hashes, so additional further check for actual values
236
    /// equality may be required.
237
6.88k
    fn get_matched_indices<'a>(
238
6.88k
        &self,
239
6.88k
        iter: impl Iterator<Item = (usize, &'a u64)>,
240
6.88k
        deleted_offset: Option<usize>,
241
6.88k
    ) -> (Vec<u32>, Vec<u64>) {
242
6.88k
        let mut input_indices = vec![];
243
6.88k
        let mut match_indices = vec![];
244
6.88k
245
6.88k
        let hash_map = self.get_map();
246
6.88k
        let next_chain = self.get_list();
247
23.8k
        for (
row_idx, hash_value16.9k
) in iter {
248
            // Get the hash and find it in the index
249
13.9k
            if let Some((_, index)) =
250
16.9k
                hash_map.get(*hash_value, |(hash, _)| 
*hash_value == *hash13.9k
)
251
            {
252
13.9k
                let mut i = *index - 1;
253
                loop {
254
35.1k
                    let 
match_row_idx32.5k
= if let Some(offset) = deleted_offset {
255
                        // This arguments means that we prune the next index way before here.
256
35.1k
                        if i < offset as u64 {
257
                            // End of the list due to pruning
258
2.65k
                            break;
259
32.5k
                        }
260
32.5k
                        i - offset as u64
261
                    } else {
262
0
                        i
263
                    };
264
32.5k
                    match_indices.push(match_row_idx);
265
32.5k
                    input_indices.push(row_idx as u32);
266
32.5k
                    // Follow the chain to get the next index value
267
32.5k
                    let next = next_chain[match_row_idx as usize];
268
32.5k
                    if next == 0 {
269
                        // end of list
270
11.2k
                        break;
271
21.2k
                    }
272
21.2k
                    i = next - 1;
273
                }
274
2.99k
            }
275
        }
276
277
6.88k
        (input_indices, match_indices)
278
6.88k
    }
279
280
    /// Matches hashes with taking limit and offset into account.
281
    /// Returns pairs of matched indices along with the starting point for next
282
    /// matching iteration (`None` if limit has not been reached).
283
    ///
284
    /// This method only compares hashes, so additional further check for actual values
285
    /// equality may be required.
286
5.16k
    fn get_matched_indices_with_limit_offset(
287
5.16k
        &self,
288
5.16k
        hash_values: &[u64],
289
5.16k
        deleted_offset: Option<usize>,
290
5.16k
        limit: usize,
291
5.16k
        offset: JoinHashMapOffset,
292
5.16k
    ) -> (Vec<u32>, Vec<u64>, Option<JoinHashMapOffset>) {
293
5.16k
        let mut input_indices = vec![];
294
5.16k
        let mut match_indices = vec![];
295
5.16k
296
5.16k
        let mut remaining_output = limit;
297
5.16k
298
5.16k
        let hash_map: &RawTable<(u64, u64)> = self.get_map();
299
5.16k
        let next_chain = self.get_list();
300
301
        // Calculate initial `hash_values` index before iterating
302
5.16k
        let 
to_skip4.97k
= match offset {
303
            // None `initial_next_idx` indicates that `initial_idx` processing has'n been started
304
4.55k
            (initial_idx, None) => initial_idx,
305
            // Zero `initial_next_idx` indicates that `initial_idx` has been processed during
306
            // previous iteration, and it should be skipped
307
251
            (initial_idx, Some(0)) => initial_idx + 1,
308
            // Otherwise, process remaining `initial_idx` matches by traversing `next_chain`,
309
            // to start with the next index
310
364
            (initial_idx, Some(initial_next_idx)) => {
311
364
                chain_traverse!(
312
364
                    input_indices,
313
364
                    match_indices,
314
364
                    hash_values,
315
364
                    next_chain,
316
364
                    initial_idx,
317
364
                    initial_next_idx,
318
364
                    deleted_offset,
319
364
                    remaining_output
320
364
                );
321
322
171
                initial_idx + 1
323
            }
324
        };
325
326
4.97k
        let mut row_idx = to_skip;
327
11.4k
        for hash_value in &
hash_values[to_skip..]4.97k
{
328
8.97k
            if let Some((_, index)) =
329
11.4k
                hash_map.get(*hash_value, |(hash, _)| 
*hash_value == *hash9.03k
)
330
            {
331
8.97k
                chain_traverse!(
332
8.97k
                    input_indices,
333
8.97k
                    match_indices,
334
8.97k
                    hash_values,
335
8.97k
                    next_chain,
336
8.97k
                    row_idx,
337
8.97k
                    index,
338
8.97k
                    deleted_offset,
339
8.97k
                    remaining_output
340
8.97k
                );
341
2.51k
            }
342
11.0k
            row_idx += 1;
343
        }
344
345
4.49k
        (input_indices, match_indices, None)
346
5.16k
    }
347
}
348
349
/// Implementation of `JoinHashMapType` for `JoinHashMap`.
350
impl JoinHashMapType for JoinHashMap {
351
    type NextType = Vec<u64>;
352
353
    // Void implementation
354
4.30k
    fn extend_zero(&mut self, _: usize) {}
355
356
    /// Get mutable references to the hash map and the next.
357
4.30k
    fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType) {
358
4.30k
        (&mut self.map, &mut self.next)
359
4.30k
    }
360
361
    /// Get a reference to the hash map.
362
5.16k
    fn get_map(&self) -> &RawTable<(u64, u64)> {
363
5.16k
        &self.map
364
5.16k
    }
365
366
    /// Get a reference to the next.
367
5.16k
    fn get_list(&self) -> &Self::NextType {
368
5.16k
        &self.next
369
5.16k
    }
370
}
371
372
impl fmt::Debug for JoinHashMap {
373
0
    fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result {
374
0
        Ok(())
375
0
    }
376
}
377
378
/// The on clause of the join, as vector of (left, right) columns.
379
pub type JoinOn = Vec<(PhysicalExprRef, PhysicalExprRef)>;
380
/// Reference for JoinOn.
381
pub type JoinOnRef<'a> = &'a [(PhysicalExprRef, PhysicalExprRef)];
382
383
/// Checks whether the schemas "left" and "right" and columns "on" represent a valid join.
384
/// They are valid whenever their columns' intersection equals the set `on`
385
1.14k
pub fn check_join_is_valid(left: &Schema, right: &Schema, on: JoinOnRef) -> Result<()> {
386
1.14k
    let left: HashSet<Column> = left
387
1.14k
        .fields()
388
1.14k
        .iter()
389
1.14k
        .enumerate()
390
8.76k
        .map(|(idx, f)| Column::new(f.name(), idx))
391
1.14k
        .collect();
392
1.14k
    let right: HashSet<Column> = right
393
1.14k
        .fields()
394
1.14k
        .iter()
395
1.14k
        .enumerate()
396
8.76k
        .map(|(idx, f)| Column::new(f.name(), idx))
397
1.14k
        .collect();
398
1.14k
399
1.14k
    check_join_set_is_valid(&left, &right, on)
400
1.14k
}
401
402
/// Checks whether the sets left, right and on compose a valid join.
403
/// They are valid whenever their intersection equals the set `on`
404
1.15k
fn check_join_set_is_valid(
405
1.15k
    left: &HashSet<Column>,
406
1.15k
    right: &HashSet<Column>,
407
1.15k
    on: &[(PhysicalExprRef, PhysicalExprRef)],
408
1.15k
) -> Result<()> {
409
1.15k
    let on_left = &on
410
1.15k
        .iter()
411
1.15k
        .flat_map(|on| 
collect_columns(&on.0)1.11k
)
412
1.15k
        .collect::<HashSet<_>>();
413
1.15k
    let left_missing = on_left.difference(left).collect::<HashSet<_>>();
414
1.15k
415
1.15k
    let on_right = &on
416
1.15k
        .iter()
417
1.15k
        .flat_map(|on| 
collect_columns(&on.1)1.11k
)
418
1.15k
        .collect::<HashSet<_>>();
419
1.15k
    let right_missing = on_right.difference(right).collect::<HashSet<_>>();
420
1.15k
421
1.15k
    if !left_missing.is_empty() | !right_missing.is_empty() {
422
2
        return plan_err!(
423
2
            "The left or right side of the join does not have all columns on \"on\": \nMissing on the left: {left_missing:?}\nMissing on the right: {right_missing:?}"
424
2
        );
425
1.15k
    };
426
1.15k
427
1.15k
    Ok(())
428
1.15k
}
429
430
/// Adjust the right out partitioning to new Column Index
431
321
pub fn adjust_right_output_partitioning(
432
321
    right_partitioning: &Partitioning,
433
321
    left_columns_len: usize,
434
321
) -> Partitioning {
435
321
    match right_partitioning {
436
176
        Partitioning::Hash(exprs, size) => {
437
176
            let new_exprs = exprs
438
176
                .iter()
439
176
                .map(|expr| add_offset_to_expr(Arc::clone(expr), left_columns_len))
440
176
                .collect();
441
176
            Partitioning::Hash(new_exprs, *size)
442
        }
443
145
        result => result.clone(),
444
    }
445
321
}
446
447
/// Replaces the right column (first index in the `on_column` tuple) with
448
/// the left column (zeroth index in the tuple) inside `right_ordering`.
449
2
fn replace_on_columns_of_right_ordering(
450
2
    on_columns: &[(PhysicalExprRef, PhysicalExprRef)],
451
2
    right_ordering: &mut [PhysicalSortExpr],
452
2
) -> Result<()> {
453
4
    for (
left_col, right_col2
) in on_columns {
454
4
        for item in 
right_ordering.iter_mut()2
{
455
4
            let new_expr = Arc::clone(&item.expr)
456
4
                .transform(|e| {
457
4
                    if e.eq(right_col) {
458
0
                        Ok(Transformed::yes(Arc::clone(left_col)))
459
                    } else {
460
4
                        Ok(Transformed::no(e))
461
                    }
462
4
                })
463
4
                .data()
?0
;
464
4
            item.expr = new_expr;
465
        }
466
    }
467
2
    Ok(())
468
2
}
469
470
2
fn offset_ordering(
471
2
    ordering: LexOrderingRef,
472
2
    join_type: &JoinType,
473
2
    offset: usize,
474
2
) -> Vec<PhysicalSortExpr> {
475
2
    match join_type {
476
        // In the case below, right ordering should be offsetted with the left
477
        // side length, since we append the right table to the left table.
478
2
        JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => ordering
479
2
            .iter()
480
4
            .map(|sort_expr| PhysicalSortExpr {
481
4
                expr: add_offset_to_expr(Arc::clone(&sort_expr.expr), offset),
482
4
                options: sort_expr.options,
483
4
            })
484
2
            .collect(),
485
0
        _ => ordering.to_vec(),
486
    }
487
2
}
488
489
/// Calculate the output ordering of a given join operation.
490
2
pub fn calculate_join_output_ordering(
491
2
    left_ordering: LexOrderingRef,
492
2
    right_ordering: LexOrderingRef,
493
2
    join_type: JoinType,
494
2
    on_columns: &[(PhysicalExprRef, PhysicalExprRef)],
495
2
    left_columns_len: usize,
496
2
    maintains_input_order: &[bool],
497
2
    probe_side: Option<JoinSide>,
498
2
) -> Option<LexOrdering> {
499
2
    let output_ordering = match maintains_input_order {
500
2
        [true, false] => {
501
            // Special case, we can prefix ordering of right side with the ordering of left side.
502
1
            if join_type == JoinType::Inner && probe_side == Some(JoinSide::Left) {
503
1
                replace_on_columns_of_right_ordering(
504
1
                    on_columns,
505
1
                    &mut right_ordering.to_vec(),
506
1
                )
507
1
                .ok()
?0
;
508
1
                merge_vectors(
509
1
                    left_ordering,
510
1
                    &offset_ordering(right_ordering, &join_type, left_columns_len),
511
1
                )
512
            } else {
513
0
                left_ordering.to_vec()
514
            }
515
        }
516
        [false, true] => {
517
            // Special case, we can prefix ordering of left side with the ordering of right side.
518
1
            if join_type == JoinType::Inner && probe_side == Some(JoinSide::Right) {
519
1
                replace_on_columns_of_right_ordering(
520
1
                    on_columns,
521
1
                    &mut right_ordering.to_vec(),
522
1
                )
523
1
                .ok()
?0
;
524
1
                merge_vectors(
525
1
                    &offset_ordering(right_ordering, &join_type, left_columns_len),
526
1
                    left_ordering,
527
1
                )
528
            } else {
529
0
                offset_ordering(right_ordering, &join_type, left_columns_len)
530
            }
531
        }
532
        // Doesn't maintain ordering, output ordering is None.
533
0
        [false, false] => return None,
534
0
        [true, true] => unreachable!("Cannot maintain ordering of both sides"),
535
0
        _ => unreachable!("Join operators can not have more than two children"),
536
    };
537
2
    (!output_ordering.is_empty()).then_some(output_ordering)
538
2
}
539
540
/// Information about the index and placement (left or right) of the columns
541
#[derive(Debug, Clone, PartialEq)]
542
pub struct ColumnIndex {
543
    /// Index of the column
544
    pub index: usize,
545
    /// Whether the column is at the left or right side
546
    pub side: JoinSide,
547
}
548
549
/// Filter applied before join output
550
#[derive(Debug, Clone)]
551
pub struct JoinFilter {
552
    /// Filter expression
553
    expression: Arc<dyn PhysicalExpr>,
554
    /// Column indices required to construct intermediate batch for filtering
555
    column_indices: Vec<ColumnIndex>,
556
    /// Physical schema of intermediate batch
557
    schema: Schema,
558
}
559
560
impl JoinFilter {
561
    /// Creates new JoinFilter
562
433
    pub fn new(
563
433
        expression: Arc<dyn PhysicalExpr>,
564
433
        column_indices: Vec<ColumnIndex>,
565
433
        schema: Schema,
566
433
    ) -> JoinFilter {
567
433
        JoinFilter {
568
433
            expression,
569
433
            column_indices,
570
433
            schema,
571
433
        }
572
433
    }
573
574
    /// Helper for building ColumnIndex vector from left and right indices
575
0
    pub fn build_column_indices(
576
0
        left_indices: Vec<usize>,
577
0
        right_indices: Vec<usize>,
578
0
    ) -> Vec<ColumnIndex> {
579
0
        left_indices
580
0
            .into_iter()
581
0
            .map(|i| ColumnIndex {
582
0
                index: i,
583
0
                side: JoinSide::Left,
584
0
            })
585
0
            .chain(right_indices.into_iter().map(|i| ColumnIndex {
586
0
                index: i,
587
0
                side: JoinSide::Right,
588
0
            }))
589
0
            .collect()
590
0
    }
591
592
    /// Filter expression
593
25.5k
    pub fn expression(&self) -> &Arc<dyn PhysicalExpr> {
594
25.5k
        &self.expression
595
25.5k
    }
596
597
    /// Column indices for intermediate batch creation
598
24.4k
    pub fn column_indices(&self) -> &[ColumnIndex] {
599
24.4k
        &self.column_indices
600
24.4k
    }
601
602
    /// Intermediate batch schema
603
27.7k
    pub fn schema(&self) -> &Schema {
604
27.7k
        &self.schema
605
27.7k
    }
606
}
607
608
/// Returns the output field given the input field. Outer joins may
609
/// insert nulls even if the input was not null
610
///
611
9.11k
fn output_join_field(old_field: &Field, join_type: &JoinType, is_left: bool) -> Field {
612
9.11k
    let force_nullable = match join_type {
613
2.35k
        JoinType::Inner => false,
614
2.19k
        JoinType::Left => !is_left, // right input is padded with nulls
615
2.19k
        JoinType::Right => is_left, // left input is padded with nulls
616
2.38k
        JoinType::Full => true,     // both inputs can be padded with nulls
617
0
        JoinType::LeftSemi => false, // doesn't introduce nulls
618
0
        JoinType::RightSemi => false, // doesn't introduce nulls
619
0
        JoinType::LeftAnti => false, // doesn't introduce nulls (or can it??)
620
0
        JoinType::RightAnti => false, // doesn't introduce nulls (or can it??)
621
    };
622
623
9.11k
    if force_nullable {
624
4.57k
        old_field.clone().with_nullable(true)
625
    } else {
626
4.54k
        old_field.clone()
627
    }
628
9.11k
}
629
630
/// Creates a schema for a join operation.
631
/// The fields from the left side are first
632
1.16k
pub fn build_join_schema(
633
1.16k
    left: &Schema,
634
1.16k
    right: &Schema,
635
1.16k
    join_type: &JoinType,
636
1.16k
) -> (Schema, Vec<ColumnIndex>) {
637
1.16k
    let (fields, column_indices): (SchemaBuilder, Vec<ColumnIndex>) = match join_type {
638
        JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
639
631
            let left_fields = left
640
631
                .fields()
641
631
                .iter()
642
4.55k
                .map(|f| output_join_field(f, join_type, true))
643
631
                .enumerate()
644
4.55k
                .map(|(index, f)| {
645
4.55k
                    (
646
4.55k
                        f,
647
4.55k
                        ColumnIndex {
648
4.55k
                            index,
649
4.55k
                            side: JoinSide::Left,
650
4.55k
                        },
651
4.55k
                    )
652
4.55k
                });
653
631
            let right_fields = right
654
631
                .fields()
655
631
                .iter()
656
4.55k
                .map(|f| output_join_field(f, join_type, false))
657
631
                .enumerate()
658
4.55k
                .map(|(index, f)| {
659
4.55k
                    (
660
4.55k
                        f,
661
4.55k
                        ColumnIndex {
662
4.55k
                            index,
663
4.55k
                            side: JoinSide::Right,
664
4.55k
                        },
665
4.55k
                    )
666
4.55k
                });
667
631
668
631
            // left then right
669
631
            left_fields.chain(right_fields).unzip()
670
        }
671
268
        JoinType::LeftSemi | JoinType::LeftAnti => left
672
268
            .fields()
673
268
            .iter()
674
268
            .cloned()
675
268
            .enumerate()
676
2.11k
            .map(|(index, f)| {
677
2.11k
                (
678
2.11k
                    f,
679
2.11k
                    ColumnIndex {
680
2.11k
                        index,
681
2.11k
                        side: JoinSide::Left,
682
2.11k
                    },
683
2.11k
                )
684
2.11k
            })
685
268
            .unzip(),
686
264
        JoinType::RightSemi | JoinType::RightAnti => right
687
264
            .fields()
688
264
            .iter()
689
264
            .cloned()
690
264
            .enumerate()
691
2.10k
            .map(|(index, f)| {
692
2.10k
                (
693
2.10k
                    f,
694
2.10k
                    ColumnIndex {
695
2.10k
                        index,
696
2.10k
                        side: JoinSide::Right,
697
2.10k
                    },
698
2.10k
                )
699
2.10k
            })
700
264
            .unzip(),
701
    };
702
703
1.16k
    (fields.finish(), column_indices)
704
1.16k
}
705
706
/// A [`OnceAsync`] can be used to run an async closure once, with subsequent calls
707
/// to [`OnceAsync::once`] returning a [`OnceFut`] to the same asynchronous computation
708
///
709
/// This is useful for joins where the results of one child are buffered in memory
710
/// and shared across potentially multiple output partitions
711
pub(crate) struct OnceAsync<T> {
712
    fut: Mutex<Option<OnceFut<T>>>,
713
}
714
715
impl<T> Default for OnceAsync<T> {
716
737
    fn default() -> Self {
717
737
        Self {
718
737
            fut: Mutex::new(None),
719
737
        }
720
737
    }
721
}
722
723
impl<T> std::fmt::Debug for OnceAsync<T> {
724
0
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
725
0
        write!(f, "OnceAsync")
726
0
    }
727
}
728
729
impl<T: 'static> OnceAsync<T> {
730
    /// If this is the first call to this function on this object, will invoke
731
    /// `f` to obtain a future and return a [`OnceFut`] referring to this
732
    ///
733
    /// If this is not the first call, will return a [`OnceFut`] referring
734
    /// to the same future as was returned by the first call
735
434
    pub(crate) fn once<F, Fut>(&self, f: F) -> OnceFut<T>
736
434
    where
737
434
        F: FnOnce() -> Fut,
738
434
        Fut: Future<Output = Result<T>> + Send + 'static,
739
434
    {
740
434
        self.fut
741
434
            .lock()
742
434
            .get_or_insert_with(|| 
OnceFut::new(f())381
)
743
434
            .clone()
744
434
    }
745
}
746
747
/// The shared future type used internally within [`OnceAsync`]
748
type OnceFutPending<T> = Shared<BoxFuture<'static, SharedResult<Arc<T>>>>;
749
750
/// A [`OnceFut`] represents a shared asynchronous computation, that will be evaluated
751
/// once for all [`Clone`]'s, with [`OnceFut::get`] providing a non-consuming interface
752
/// to drive the underlying [`Future`] to completion
753
pub(crate) struct OnceFut<T> {
754
    state: OnceFutState<T>,
755
}
756
757
impl<T> Clone for OnceFut<T> {
758
434
    fn clone(&self) -> Self {
759
434
        Self {
760
434
            state: self.state.clone(),
761
434
        }
762
434
    }
763
}
764
765
/// A shared state between statistic aggregators for a join
766
/// operation.
767
#[derive(Clone, Debug, Default)]
768
struct PartialJoinStatistics {
769
    pub num_rows: usize,
770
    pub column_statistics: Vec<ColumnStatistics>,
771
}
772
773
/// Estimate the statistics for the given join's output.
774
0
pub(crate) fn estimate_join_statistics(
775
0
    left: Arc<dyn ExecutionPlan>,
776
0
    right: Arc<dyn ExecutionPlan>,
777
0
    on: JoinOn,
778
0
    join_type: &JoinType,
779
0
    schema: &Schema,
780
0
) -> Result<Statistics> {
781
0
    let left_stats = left.statistics()?;
782
0
    let right_stats = right.statistics()?;
783
784
0
    let join_stats = estimate_join_cardinality(join_type, left_stats, right_stats, &on);
785
0
    let (num_rows, column_statistics) = match join_stats {
786
0
        Some(stats) => (Precision::Inexact(stats.num_rows), stats.column_statistics),
787
0
        None => (Precision::Absent, Statistics::unknown_column(schema)),
788
    };
789
0
    Ok(Statistics {
790
0
        num_rows,
791
0
        total_byte_size: Precision::Absent,
792
0
        column_statistics,
793
0
    })
794
0
}
795
796
// Estimate the cardinality for the given join with input statistics.
797
44
fn estimate_join_cardinality(
798
44
    join_type: &JoinType,
799
44
    left_stats: Statistics,
800
44
    right_stats: Statistics,
801
44
    on: &JoinOn,
802
44
) -> Option<PartialJoinStatistics> {
803
44
    let (left_col_stats, right_col_stats) = on
804
44
        .iter()
805
52
        .map(|(left, right)| {
806
52
            match (
807
52
                left.as_any().downcast_ref::<Column>(),
808
52
                right.as_any().downcast_ref::<Column>(),
809
            ) {
810
52
                (Some(left), Some(right)) => (
811
52
                    left_stats.column_statistics[left.index()].clone(),
812
52
                    right_stats.column_statistics[right.index()].clone(),
813
52
                ),
814
0
                _ => (
815
0
                    ColumnStatistics::new_unknown(),
816
0
                    ColumnStatistics::new_unknown(),
817
0
                ),
818
            }
819
52
        })
820
44
        .unzip::<_, _, Vec<_>, Vec<_>>();
821
44
822
44
    match join_type {
823
        JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
824
29
            let 
ij_cardinality23
= estimate_inner_join_cardinality(
825
29
                Statistics {
826
29
                    num_rows: left_stats.num_rows,
827
29
                    total_byte_size: Precision::Absent,
828
29
                    column_statistics: left_col_stats,
829
29
                },
830
29
                Statistics {
831
29
                    num_rows: right_stats.num_rows,
832
29
                    total_byte_size: Precision::Absent,
833
29
                    column_statistics: right_col_stats,
834
29
                },
835
29
            )
?6
;
836
837
            // The cardinality for inner join can also be used to estimate
838
            // the cardinality of left/right/full outer joins as long as it
839
            // it is greater than the minimum cardinality constraints of these
840
            // joins (so that we don't underestimate the cardinality).
841
23
            let cardinality = match join_type {
842
17
                JoinType::Inner => ij_cardinality,
843
2
                JoinType::Left => ij_cardinality.max(&left_stats.num_rows),
844
2
                JoinType::Right => ij_cardinality.max(&right_stats.num_rows),
845
2
                JoinType::Full => ij_cardinality
846
2
                    .max(&left_stats.num_rows)
847
2
                    .add(&ij_cardinality.max(&right_stats.num_rows))
848
2
                    .sub(&ij_cardinality),
849
0
                _ => unreachable!(),
850
            };
851
852
            Some(PartialJoinStatistics {
853
23
                num_rows: *cardinality.get_value()
?0
,
854
                // We don't do anything specific here, just combine the existing
855
                // statistics which might yield subpar results (although it is
856
                // true, esp regarding min/max). For a better estimation, we need
857
                // filter selectivity analysis first.
858
23
                column_statistics: left_stats
859
23
                    .column_statistics
860
23
                    .into_iter()
861
23
                    .chain(right_stats.column_statistics)
862
23
                    .collect(),
863
            })
864
        }
865
866
        // For SemiJoins estimation result is either zero, in cases when inputs
867
        // are non-overlapping according to statistics, or equal to number of rows
868
        // for outer input
869
        JoinType::LeftSemi | JoinType::RightSemi => {
870
9
            let (outer_stats, inner_stats) = match join_type {
871
8
                JoinType::LeftSemi => (left_stats, right_stats),
872
1
                _ => (right_stats, left_stats),
873
            };
874
9
            let 
cardinality7
= match estimate_disjoint_inputs(&outer_stats, &inner_stats) {
875
3
                Some(estimation) => *estimation.get_value()
?0
,
876
6
                None => *outer_stats.num_rows.get_value()
?2
,
877
            };
878
879
7
            Some(PartialJoinStatistics {
880
7
                num_rows: cardinality,
881
7
                column_statistics: outer_stats.column_statistics,
882
7
            })
883
        }
884
885
        // For AntiJoins estimation always equals to outer statistics, as
886
        // non-overlapping inputs won't affect estimation
887
        JoinType::LeftAnti | JoinType::RightAnti => {
888
6
            let outer_stats = match join_type {
889
5
                JoinType::LeftAnti => left_stats,
890
1
                _ => right_stats,
891
            };
892
893
            Some(PartialJoinStatistics {
894
6
                num_rows: *outer_stats.num_rows.get_value()
?0
,
895
6
                column_statistics: outer_stats.column_statistics,
896
            })
897
        }
898
    }
899
44
}
900
901
/// Estimate the inner join cardinality by using the basic building blocks of
902
/// column-level statistics and the total row count. This is a very naive and
903
/// a very conservative implementation that can quickly give up if there is not
904
/// enough input statistics.
905
52
fn estimate_inner_join_cardinality(
906
52
    left_stats: Statistics,
907
52
    right_stats: Statistics,
908
52
) -> Option<Precision<usize>> {
909
    // Immediately return if inputs considered as non-overlapping
910
52
    if let Some(
estimation10
) = estimate_disjoint_inputs(&left_stats, &right_stats) {
911
10
        return Some(estimation);
912
42
    };
913
42
914
42
    // The algorithm here is partly based on the non-histogram selectivity estimation
915
42
    // from Spark's Catalyst optimizer.
916
42
    let mut join_selectivity = Precision::Absent;
917
47
    for (left_stat, right_stat) in left_stats
918
42
        .column_statistics
919
42
        .iter()
920
42
        .zip(right_stats.column_statistics.iter())
921
    {
922
        // Break if any of statistics bounds are undefined
923
47
        if left_stat.min_value.get_value().is_none()
924
39
            || left_stat.max_value.get_value().is_none()
925
37
            || right_stat.min_value.get_value().is_none()
926
37
            || right_stat.max_value.get_value().is_none()
927
        {
928
10
            return None;
929
37
        }
930
37
931
37
        let left_max_distinct = max_distinct_count(&left_stats.num_rows, left_stat);
932
37
        let right_max_distinct = max_distinct_count(&right_stats.num_rows, right_stat);
933
37
        let max_distinct = left_max_distinct.max(&right_max_distinct);
934
37
        if max_distinct.get_value().is_some() {
935
37
            // Seems like there are a few implementations of this algorithm that implement
936
37
            // exponential decay for the selectivity (like Hive's Optiq Optimizer). Needs
937
37
            // further exploration.
938
37
            join_selectivity = max_distinct;
939
37
        }
0
940
    }
941
942
    // With the assumption that the smaller input's domain is generally represented in the bigger
943
    // input's domain, we can estimate the inner join's cardinality by taking the cartesian product
944
    // of the two inputs and normalizing it by the selectivity factor.
945
32
    let left_num_rows = left_stats.num_rows.get_value()
?0
;
946
32
    let right_num_rows = right_stats.num_rows.get_value()
?0
;
947
0
    match join_selectivity {
948
0
        Precision::Exact(value) if value > 0 => {
949
0
            Some(Precision::Exact((left_num_rows * right_num_rows) / value))
950
        }
951
32
        Precision::Inexact(
value30
) if value > 0 => {
952
30
            Some(Precision::Inexact((left_num_rows * right_num_rows) / value))
953
        }
954
        // Since we don't have any information about the selectivity (which is derived
955
        // from the number of distinct rows information) we can give up here for now.
956
        // And let other passes handle this (otherwise we would need to produce an
957
        // overestimation using just the cartesian product).
958
2
        _ => None,
959
    }
960
52
}
961
962
/// Estimates if inputs are non-overlapping, using input statistics.
963
/// If inputs are disjoint, returns zero estimation, otherwise returns None
964
61
fn estimate_disjoint_inputs(
965
61
    left_stats: &Statistics,
966
61
    right_stats: &Statistics,
967
61
) -> Option<Precision<usize>> {
968
70
    for (left_stat, right_stat) in left_stats
969
61
        .column_statistics
970
61
        .iter()
971
61
        .zip(right_stats.column_statistics.iter())
972
    {
973
        // If there is no overlap in any of the join columns, this means the join
974
        // itself is disjoint and the cardinality is 0. Though we can only assume
975
        // this when the statistics are exact (since it is a very strong assumption).
976
70
        let left_min_val = left_stat.min_value.get_value();
977
70
        let right_max_val = right_stat.max_value.get_value();
978
70
        if left_min_val.is_some()
979
55
            && right_max_val.is_some()
980
55
            && left_min_val > right_max_val
981
        {
982
            return Some(
983
7
                if left_stat.min_value.is_exact().unwrap_or(false)
984
0
                    && right_stat.max_value.is_exact().unwrap_or(false)
985
                {
986
0
                    Precision::Exact(0)
987
                } else {
988
7
                    Precision::Inexact(0)
989
                },
990
            );
991
63
        }
992
63
993
63
        let left_max_val = left_stat.max_value.get_value();
994
63
        let right_min_val = right_stat.min_value.get_value();
995
63
        if left_max_val.is_some()
996
53
            && right_min_val.is_some()
997
53
            && left_max_val < right_min_val
998
        {
999
            return Some(
1000
6
                if left_stat.max_value.is_exact().unwrap_or(false)
1001
0
                    && right_stat.min_value.is_exact().unwrap_or(false)
1002
                {
1003
0
                    Precision::Exact(0)
1004
                } else {
1005
6
                    Precision::Inexact(0)
1006
                },
1007
            );
1008
57
        }
1009
    }
1010
1011
48
    None
1012
61
}
1013
1014
/// Estimate the number of maximum distinct values that can be present in the
1015
/// given column from its statistics. If distinct_count is available, uses it
1016
/// directly. Otherwise, if the column is numeric and has min/max values, it
1017
/// estimates the maximum distinct count from those.
1018
74
fn max_distinct_count(
1019
74
    num_rows: &Precision<usize>,
1020
74
    stats: &ColumnStatistics,
1021
74
) -> Precision<usize> {
1022
74
    match &stats.distinct_count {
1023
38
        &dc @ (Precision::Exact(_) | Precision::Inexact(_)) => dc,
1024
        _ => {
1025
            // The number can never be greater than the number of rows we have
1026
            // minus the nulls (since they don't count as distinct values).
1027
36
            let result = match num_rows {
1028
0
                Precision::Absent => Precision::Absent,
1029
36
                Precision::Inexact(count) => {
1030
36
                    // To safeguard against inexact number of rows (e.g. 0) being smaller than
1031
36
                    // an exact null count we need to do a checked subtraction.
1032
36
                    match count.checked_sub(*stats.null_count.get_value().unwrap_or(&0)) {
1033
2
                        None => Precision::Inexact(0),
1034
34
                        Some(non_null_count) => Precision::Inexact(non_null_count),
1035
                    }
1036
                }
1037
0
                Precision::Exact(count) => {
1038
0
                    let count = count - stats.null_count.get_value().unwrap_or(&0);
1039
0
                    if stats.null_count.is_exact().unwrap_or(false) {
1040
0
                        Precision::Exact(count)
1041
                    } else {
1042
0
                        Precision::Inexact(count)
1043
                    }
1044
                }
1045
            };
1046
            // Cap the estimate using the number of possible values:
1047
36
            if let (Some(min), Some(max)) =
1048
36
                (stats.min_value.get_value(), stats.max_value.get_value())
1049
            {
1050
36
                if let Some(
range_dc34
) = Interval::try_new(min.clone(), max.clone())
1051
36
                    .ok()
1052
36
                    .and_then(|e| e.cardinality())
1053
                {
1054
34
                    let range_dc = range_dc as usize;
1055
                    // Note that the `unwrap` calls in the below statement are safe.
1056
34
                    return if matches!(result, Precision::Absent)
1057
34
                        || &range_dc < result.get_value().unwrap()
1058
                    {
1059
16
                        if stats.min_value.is_exact().unwrap()
1060
0
                            && stats.max_value.is_exact().unwrap()
1061
                        {
1062
0
                            Precision::Exact(range_dc)
1063
                        } else {
1064
16
                            Precision::Inexact(range_dc)
1065
                        }
1066
                    } else {
1067
18
                        result
1068
                    };
1069
2
                }
1070
0
            }
1071
1072
2
            result
1073
        }
1074
    }
1075
74
}
1076
1077
enum OnceFutState<T> {
1078
    Pending(OnceFutPending<T>),
1079
    Ready(SharedResult<Arc<T>>),
1080
}
1081
1082
impl<T> Clone for OnceFutState<T> {
1083
434
    fn clone(&self) -> Self {
1084
434
        match self {
1085
434
            Self::Pending(p) => Self::Pending(p.clone()),
1086
0
            Self::Ready(r) => Self::Ready(r.clone()),
1087
        }
1088
434
    }
1089
}
1090
1091
impl<T: 'static> OnceFut<T> {
1092
    /// Create a new [`OnceFut`] from a [`Future`]
1093
1.78k
    pub(crate) fn new<Fut>(fut: Fut) -> Self
1094
1.78k
    where
1095
1.78k
        Fut: Future<Output = Result<T>> + Send + 'static,
1096
1.78k
    {
1097
1.78k
        Self {
1098
1.78k
            state: OnceFutState::Pending(
1099
1.78k
                fut.map(|res| res.map(Arc::new).map_err(Arc::new))
1100
1.78k
                    .boxed()
1101
1.78k
                    .shared(),
1102
1.78k
            ),
1103
1.78k
        }
1104
1.78k
    }
1105
1106
    /// Get the result of the computation if it is ready, without consuming it
1107
3
    pub(crate) fn get(&mut self, cx: &mut Context<'_>) -> Poll<Result<&T>> {
1108
3
        if let OnceFutState::Pending(fut) = &mut self.state {
1109
3
            let r = 
ready!0
(fut.poll_unpin(cx));
1110
3
            self.state = OnceFutState::Ready(r);
1111
0
        }
1112
1113
        // Cannot use loop as this would trip up the borrow checker
1114
3
        match &self.state {
1115
0
            OnceFutState::Pending(_) => unreachable!(),
1116
3
            OnceFutState::Ready(r) => Poll::Ready(
1117
3
                r.as_ref()
1118
3
                    .map(|r| 
r.as_ref()1
)
1119
3
                    .map_err(|e| 
DataFusionError::External(Box::new(Arc::clone(e)))2
),
1120
3
            ),
1121
        }
1122
3
    }
1123
1124
    /// Get shared reference to the result of the computation if it is ready, without consuming it
1125
15.4k
    pub(crate) fn get_shared(&mut self, cx: &mut Context<'_>) -> Poll<Result<Arc<T>>> {
1126
15.4k
        if let OnceFutState::Pending(
fut3.32k
) = &mut self.state {
1127
3.32k
            let 
r1.83k
=
ready!1.49k
(fut.poll_unpin(cx));
1128
1.83k
            self.state = OnceFutState::Ready(r);
1129
12.1k
        }
1130
1131
13.9k
        match &self.state {
1132
0
            OnceFutState::Pending(_) => unreachable!(),
1133
13.9k
            OnceFutState::Ready(r) => Poll::Ready(
1134
13.9k
                r.clone()
1135
13.9k
                    .map_err(|e| 
DataFusionError::External(Box::new(e))24
),
1136
13.9k
            ),
1137
        }
1138
15.4k
    }
1139
}
1140
1141
/// Some type `join_type` of join need to maintain the matched indices bit map for the left side, and
1142
/// use the bit map to generate the part of result of the join.
1143
///
1144
/// For example of the `Left` join, in each iteration of right side, can get the matched result, but need
1145
/// to maintain the matched indices bit map to get the unmatched row for the left side.
1146
20.8k
pub(crate) fn need_produce_result_in_final(join_type: JoinType) -> bool {
1147
16.5k
    matches!(
1148
20.8k
        join_type,
1149
        JoinType::Left | JoinType::LeftAnti | JoinType::LeftSemi | JoinType::Full
1150
    )
1151
20.8k
}
1152
1153
/// In the end of join execution, need to use bit map of the matched
1154
/// indices to generate the final left and right indices.
1155
///
1156
/// For example:
1157
///
1158
/// 1. left_bit_map: `[true, false, true, true, false]`
1159
/// 2. join_type: `Left`
1160
///
1161
/// The result is: `([1,4], [null, null])`
1162
854
pub(crate) fn get_final_indices_from_bit_map(
1163
854
    left_bit_map: &BooleanBufferBuilder,
1164
854
    join_type: JoinType,
1165
854
) -> (UInt64Array, UInt32Array) {
1166
854
    let left_size = left_bit_map.len();
1167
854
    let left_indices = if join_type == JoinType::LeftSemi {
1168
201
        (0..left_size)
1169
1.42k
            .filter_map(|idx| (left_bit_map.get_bit(idx)).then_some(idx as u64))
1170
201
            .collect::<UInt64Array>()
1171
    } else {
1172
        // just for `Left`, `LeftAnti` and `Full` join
1173
        // `LeftAnti`, `Left` and `Full` will produce the unmatched left row finally
1174
653
        (0..left_size)
1175
4.34k
            .filter_map(|idx| (!left_bit_map.get_bit(idx)).then_some(idx as u64))
1176
653
            .collect::<UInt64Array>()
1177
    };
1178
    // right_indices
1179
    // all the element in the right side is None
1180
854
    let mut builder = UInt32Builder::with_capacity(left_indices.len());
1181
854
    builder.append_nulls(left_indices.len());
1182
854
    let right_indices = builder.finish();
1183
854
    (left_indices, right_indices)
1184
854
}
1185
1186
23.0k
pub(crate) fn apply_join_filter_to_indices(
1187
23.0k
    build_input_buffer: &RecordBatch,
1188
23.0k
    probe_batch: &RecordBatch,
1189
23.0k
    build_indices: UInt64Array,
1190
23.0k
    probe_indices: UInt32Array,
1191
23.0k
    filter: &JoinFilter,
1192
23.0k
    build_side: JoinSide,
1193
23.0k
) -> Result<(UInt64Array, UInt32Array)> {
1194
23.0k
    if build_indices.is_empty() && 
probe_indices.is_empty()872
{
1195
872
        return Ok((build_indices, probe_indices));
1196
22.2k
    };
1197
1198
22.2k
    let intermediate_batch = build_batch_from_indices(
1199
22.2k
        filter.schema(),
1200
22.2k
        build_input_buffer,
1201
22.2k
        probe_batch,
1202
22.2k
        &build_indices,
1203
22.2k
        &probe_indices,
1204
22.2k
        filter.column_indices(),
1205
22.2k
        build_side,
1206
22.2k
    )
?0
;
1207
22.2k
    let filter_result = filter
1208
22.2k
        .expression()
1209
22.2k
        .evaluate(&intermediate_batch)
?0
1210
22.2k
        .into_array(intermediate_batch.num_rows())
?0
;
1211
22.2k
    let mask = as_boolean_array(&filter_result)
?0
;
1212
1213
22.2k
    let left_filtered = compute::filter(&build_indices, mask)
?0
;
1214
22.2k
    let right_filtered = compute::filter(&probe_indices, mask)
?0
;
1215
22.2k
    Ok((
1216
22.2k
        downcast_array(left_filtered.as_ref()),
1217
22.2k
        downcast_array(right_filtered.as_ref()),
1218
22.2k
    ))
1219
23.0k
}
1220
1221
/// Returns a new [RecordBatch] by combining the `left` and `right` according to `indices`.
1222
/// The resulting batch has [Schema] `schema`.
1223
46.6k
pub(crate) fn build_batch_from_indices(
1224
46.6k
    schema: &Schema,
1225
46.6k
    build_input_buffer: &RecordBatch,
1226
46.6k
    probe_batch: &RecordBatch,
1227
46.6k
    build_indices: &UInt64Array,
1228
46.6k
    probe_indices: &UInt32Array,
1229
46.6k
    column_indices: &[ColumnIndex],
1230
46.6k
    build_side: JoinSide,
1231
46.6k
) -> Result<RecordBatch> {
1232
46.6k
    if schema.fields().is_empty() {
1233
0
        let options = RecordBatchOptions::new()
1234
0
            .with_match_field_names(true)
1235
0
            .with_row_count(Some(build_indices.len()));
1236
0
1237
0
        return Ok(RecordBatch::try_new_with_options(
1238
0
            Arc::new(schema.clone()),
1239
0
            vec![],
1240
0
            &options,
1241
0
        )?);
1242
46.6k
    }
1243
46.6k
1244
46.6k
    // build the columns of the new [RecordBatch]:
1245
46.6k
    // 1. pick whether the column is from the left or right
1246
46.6k
    // 2. based on the pick, `take` items from the different RecordBatches
1247
46.6k
    let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(schema.fields().len());
1248
1249
355k
    for 
column_index308k
in column_indices {
1250
308k
        let array = if column_index.side == build_side {
1251
154k
            let array = build_input_buffer.column(column_index.index);
1252
154k
            if array.is_empty() || 
build_indices.null_count() == build_indices.len()150k
{
1253
                // Outer join would generate a null index when finding no match at our side.
1254
                // Therefore, it's possible we are empty but need to populate an n-length null array,
1255
                // where n is the length of the index array.
1256
72.3k
                assert_eq!(build_indices.null_count(), build_indices.len());
1257
72.3k
                new_null_array(array.data_type(), build_indices.len())
1258
            } else {
1259
82.2k
                compute::take(array.as_ref(), build_indices, None)
?0
1260
            }
1261
        } else {
1262
153k
            let array = probe_batch.column(column_index.index);
1263
153k
            if array.is_empty() || 
probe_indices.null_count() == probe_indices.len()133k
{
1264
73.0k
                assert_eq!(probe_indices.null_count(), probe_indices.len());
1265
73.0k
                new_null_array(array.data_type(), probe_indices.len())
1266
            } else {
1267
80.8k
                compute::take(array.as_ref(), probe_indices, None)
?0
1268
            }
1269
        };
1270
308k
        columns.push(array);
1271
    }
1272
46.6k
    Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)
?0
)
1273
46.6k
}
1274
1275
/// The input is the matched indices for left and right and
1276
/// adjust the indices according to the join type
1277
17.3k
pub(crate) fn adjust_indices_by_join_type(
1278
17.3k
    left_indices: UInt64Array,
1279
17.3k
    right_indices: UInt32Array,
1280
17.3k
    adjust_range: Range<usize>,
1281
17.3k
    join_type: JoinType,
1282
17.3k
    preserve_order_for_right: bool,
1283
17.3k
) -> (UInt64Array, UInt32Array) {
1284
17.3k
    match join_type {
1285
        JoinType::Inner => {
1286
            // matched
1287
3.70k
            (left_indices, right_indices)
1288
        }
1289
        JoinType::Left => {
1290
            // matched
1291
650
            (left_indices, right_indices)
1292
            // unmatched left row will be produced in the end of loop, and it has been set in the left visited bitmap
1293
        }
1294
        JoinType::Right => {
1295
            // combine the matched and unmatched right result together
1296
3.66k
            append_right_indices(
1297
3.66k
                left_indices,
1298
3.66k
                right_indices,
1299
3.66k
                adjust_range,
1300
3.66k
                preserve_order_for_right,
1301
3.66k
            )
1302
        }
1303
        JoinType::Full => {
1304
696
            append_right_indices(left_indices, right_indices, adjust_range, false)
1305
        }
1306
        JoinType::RightSemi => {
1307
            // need to remove the duplicated record in the right side
1308
3.66k
            let right_indices = get_semi_indices(adjust_range, &right_indices);
1309
3.66k
            // the left_indices will not be used later for the `right semi` join
1310
3.66k
            (left_indices, right_indices)
1311
        }
1312
        JoinType::RightAnti => {
1313
            // need to remove the duplicated record in the right side
1314
            // get the anti index for the right side
1315
3.66k
            let right_indices = get_anti_indices(adjust_range, &right_indices);
1316
3.66k
            // the left_indices will not be used later for the `right anti` join
1317
3.66k
            (left_indices, right_indices)
1318
        }
1319
        JoinType::LeftSemi | JoinType::LeftAnti => {
1320
            // matched or unmatched left row will be produced in the end of loop
1321
            // When visit the right batch, we can output the matched left row and don't need to wait the end of loop
1322
1.26k
            (
1323
1.26k
                UInt64Array::from_iter_values(vec![]),
1324
1.26k
                UInt32Array::from_iter_values(vec![]),
1325
1.26k
            )
1326
        }
1327
    }
1328
17.3k
}
1329
1330
/// Appends right indices to left indices based on the specified order mode.
1331
///
1332
/// The function operates in two modes:
1333
/// 1. If `preserve_order_for_right` is true, probe matched and unmatched indices
1334
///    are inserted in order using the `append_probe_indices_in_order()` method.
1335
/// 2. Otherwise, unmatched probe indices are simply appended after matched ones.
1336
///
1337
/// # Parameters
1338
/// - `left_indices`: UInt64Array of left indices.
1339
/// - `right_indices`: UInt32Array of right indices.
1340
/// - `adjust_range`: Range to adjust the right indices.
1341
/// - `preserve_order_for_right`: Boolean flag to determine the mode of operation.
1342
///
1343
/// # Returns
1344
/// A tuple of updated `UInt64Array` and `UInt32Array`.
1345
4.35k
pub(crate) fn append_right_indices(
1346
4.35k
    left_indices: UInt64Array,
1347
4.35k
    right_indices: UInt32Array,
1348
4.35k
    adjust_range: Range<usize>,
1349
4.35k
    preserve_order_for_right: bool,
1350
4.35k
) -> (UInt64Array, UInt32Array) {
1351
4.35k
    if preserve_order_for_right {
1352
3.46k
        append_probe_indices_in_order(left_indices, right_indices, adjust_range)
1353
    } else {
1354
895
        let right_unmatched_indices = get_anti_indices(adjust_range, &right_indices);
1355
895
1356
895
        if right_unmatched_indices.is_empty() {
1357
329
            (left_indices, right_indices)
1358
        } else {
1359
566
            let unmatched_size = right_unmatched_indices.len();
1360
566
            // the new left indices: left_indices + null array
1361
566
            // the new right indices: right_indices + right_unmatched_indices
1362
566
            let new_left_indices = left_indices
1363
566
                .iter()
1364
566
                .chain(std::iter::repeat(None).take(unmatched_size))
1365
566
                .collect();
1366
566
            let new_right_indices = right_indices
1367
566
                .iter()
1368
566
                .chain(right_unmatched_indices.iter())
1369
566
                .collect();
1370
566
            (new_left_indices, new_right_indices)
1371
        }
1372
    }
1373
4.35k
}
1374
1375
/// Returns `range` indices which are not present in `input_indices`
1376
4.56k
pub(crate) fn get_anti_indices<T: ArrowPrimitiveType>(
1377
4.56k
    range: Range<usize>,
1378
4.56k
    input_indices: &PrimitiveArray<T>,
1379
4.56k
) -> PrimitiveArray<T>
1380
4.56k
where
1381
4.56k
    NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
1382
4.56k
{
1383
4.56k
    let mut bitmap = BooleanBufferBuilder::new(range.len());
1384
4.56k
    bitmap.append_n(range.len(), false);
1385
4.56k
    input_indices
1386
4.56k
        .iter()
1387
4.56k
        .flatten()
1388
4.80M
        .map(|v| v.as_usize())
1389
4.80M
        .filter(|v| range.contains(v))
1390
4.80M
        .for_each(|v| {
1391
4.80M
            bitmap.set_bit(v - range.start, true);
1392
4.80M
        });
1393
4.56k
1394
4.56k
    let offset = range.start;
1395
4.56k
1396
4.56k
    // get the anti index
1397
4.56k
    (range)
1398
12.3k
        .filter_map(|idx| {
1399
12.3k
            (!bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx))
1400
12.3k
        })
1401
4.56k
        .collect()
1402
4.56k
}
1403
1404
/// Returns intersection of `range` and `input_indices` omitting duplicates
1405
3.66k
pub(crate) fn get_semi_indices<T: ArrowPrimitiveType>(
1406
3.66k
    range: Range<usize>,
1407
3.66k
    input_indices: &PrimitiveArray<T>,
1408
3.66k
) -> PrimitiveArray<T>
1409
3.66k
where
1410
3.66k
    NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
1411
3.66k
{
1412
3.66k
    let mut bitmap = BooleanBufferBuilder::new(range.len());
1413
3.66k
    bitmap.append_n(range.len(), false);
1414
3.66k
    input_indices
1415
3.66k
        .iter()
1416
3.66k
        .flatten()
1417
4.80M
        .map(|v| v.as_usize())
1418
4.80M
        .filter(|v| range.contains(v))
1419
4.80M
        .for_each(|v| {
1420
4.80M
            bitmap.set_bit(v - range.start, true);
1421
4.80M
        });
1422
3.66k
1423
3.66k
    let offset = range.start;
1424
3.66k
1425
3.66k
    // get the semi index
1426
3.66k
    (range)
1427
10.4k
        .filter_map(|idx| {
1428
10.4k
            (bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx))
1429
10.4k
        })
1430
3.66k
        .collect()
1431
3.66k
}
1432
1433
/// Appends probe indices in order by considering the given build indices.
1434
///
1435
/// This function constructs new build and probe indices by iterating through
1436
/// the provided indices, and appends any missing values between previous and
1437
/// current probe index with a corresponding null build index.
1438
///
1439
/// # Parameters
1440
///
1441
/// - `build_indices`: `PrimitiveArray` of `UInt64Type` containing build indices.
1442
/// - `probe_indices`: `PrimitiveArray` of `UInt32Type` containing probe indices.
1443
/// - `range`: The range of indices to consider.
1444
///
1445
/// # Returns
1446
///
1447
/// A tuple of two arrays:
1448
/// - A `PrimitiveArray` of `UInt64Type` with the newly constructed build indices.
1449
/// - A `PrimitiveArray` of `UInt32Type` with the newly constructed probe indices.
1450
3.46k
fn append_probe_indices_in_order(
1451
3.46k
    build_indices: PrimitiveArray<UInt64Type>,
1452
3.46k
    probe_indices: PrimitiveArray<UInt32Type>,
1453
3.46k
    range: Range<usize>,
1454
3.46k
) -> (PrimitiveArray<UInt64Type>, PrimitiveArray<UInt32Type>) {
1455
3.46k
    // Builders for new indices:
1456
3.46k
    let mut new_build_indices = UInt64Builder::new();
1457
3.46k
    let mut new_probe_indices = UInt32Builder::new();
1458
3.46k
    // Set previous index as the start index for the initial loop:
1459
3.46k
    let mut prev_index = range.start as u32;
1460
3.46k
    // Zip the two iterators.
1461
3.46k
    debug_assert!(build_indices.len() == probe_indices.len());
1462
4.80M
    for (build_index, probe_index) in build_indices
1463
3.46k
        .values()
1464
3.46k
        .into_iter()
1465
3.46k
        .zip(probe_indices.values().into_iter())
1466
    {
1467
        // Append values between previous and current probe index with null build index:
1468
4.80M
        for 
value1.20k
in prev_index..*probe_index {
1469
1.20k
            new_probe_indices.append_value(value);
1470
1.20k
            new_build_indices.append_null();
1471
1.20k
        }
1472
        // Append current indices:
1473
4.80M
        new_probe_indices.append_value(*probe_index);
1474
4.80M
        new_build_indices.append_value(*build_index);
1475
4.80M
        // Set current probe index as previous for the next iteration:
1476
4.80M
        prev_index = probe_index + 1;
1477
    }
1478
    // Append remaining probe indices after the last valid probe index with null build index.
1479
3.46k
    for 
value1.39k
in prev_index..range.end as u32 {
1480
1.39k
        new_probe_indices.append_value(value);
1481
1.39k
        new_build_indices.append_null();
1482
1.39k
    }
1483
    // Build arrays and return:
1484
3.46k
    (new_build_indices.finish(), new_probe_indices.finish())
1485
3.46k
}
1486
1487
/// Metrics for build & probe joins
1488
#[derive(Clone, Debug)]
1489
pub(crate) struct BuildProbeJoinMetrics {
1490
    /// Total time for collecting build-side of join
1491
    pub(crate) build_time: metrics::Time,
1492
    /// Number of batches consumed by build-side
1493
    pub(crate) build_input_batches: metrics::Count,
1494
    /// Number of rows consumed by build-side
1495
    pub(crate) build_input_rows: metrics::Count,
1496
    /// Memory used by build-side in bytes
1497
    pub(crate) build_mem_used: metrics::Gauge,
1498
    /// Total time for joining probe-side batches to the build-side batches
1499
    pub(crate) join_time: metrics::Time,
1500
    /// Number of batches consumed by probe-side of this operator
1501
    pub(crate) input_batches: metrics::Count,
1502
    /// Number of rows consumed by probe-side this operator
1503
    pub(crate) input_rows: metrics::Count,
1504
    /// Number of batches produced by this operator
1505
    pub(crate) output_batches: metrics::Count,
1506
    /// Number of rows produced by this operator
1507
    pub(crate) output_rows: metrics::Count,
1508
}
1509
1510
impl BuildProbeJoinMetrics {
1511
1.83k
    pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self {
1512
1.83k
        let join_time = MetricBuilder::new(metrics).subset_time("join_time", partition);
1513
1.83k
1514
1.83k
        let build_time = MetricBuilder::new(metrics).subset_time("build_time", partition);
1515
1.83k
1516
1.83k
        let build_input_batches =
1517
1.83k
            MetricBuilder::new(metrics).counter("build_input_batches", partition);
1518
1.83k
1519
1.83k
        let build_input_rows =
1520
1.83k
            MetricBuilder::new(metrics).counter("build_input_rows", partition);
1521
1.83k
1522
1.83k
        let build_mem_used =
1523
1.83k
            MetricBuilder::new(metrics).gauge("build_mem_used", partition);
1524
1.83k
1525
1.83k
        let input_batches =
1526
1.83k
            MetricBuilder::new(metrics).counter("input_batches", partition);
1527
1.83k
1528
1.83k
        let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition);
1529
1.83k
1530
1.83k
        let output_batches =
1531
1.83k
            MetricBuilder::new(metrics).counter("output_batches", partition);
1532
1.83k
1533
1.83k
        let output_rows = MetricBuilder::new(metrics).output_rows(partition);
1534
1.83k
1535
1.83k
        Self {
1536
1.83k
            build_time,
1537
1.83k
            build_input_batches,
1538
1.83k
            build_input_rows,
1539
1.83k
            build_mem_used,
1540
1.83k
            join_time,
1541
1.83k
            input_batches,
1542
1.83k
            input_rows,
1543
1.83k
            output_batches,
1544
1.83k
            output_rows,
1545
1.83k
        }
1546
1.83k
    }
1547
}
1548
1549
/// Dispatches on the outcome of one step of a stateful stream.
///
/// The macro expands to a `match` over a `Result<StatefulStreamResult<_>>` and
/// is meant to be used inside a `loop` in a `poll_next`-style function, since
/// one of its arms expands to `continue`.
///
/// # Cases
///
/// - `Ok(StatefulStreamResult::Continue)`: executes `continue`, so the
///   enclosing loop runs the next state transition.
/// - `Ok(StatefulStreamResult::Ready(result))`: evaluates to
///   `Poll::Ready(Ok(result).transpose())`, i.e. `Poll::Ready(Some(Ok(v)))`
///   when `result` is `Some(v)` and `Poll::Ready(None)` when it is `None`.
/// - `Err(e)`: evaluates to `Poll::Ready(Some(Err(e)))`.
///
/// # Arguments
///
/// * `$match_case`: An expression that evaluates to a `Result<StatefulStreamResult<_>>`.
#[macro_export]
macro_rules! handle_state {
    ($match_case:expr) => {
        match $match_case {
            Ok(StatefulStreamResult::Continue) => continue,
            Ok(StatefulStreamResult::Ready(result)) => Poll::Ready(Ok(result).transpose()),
            Err(e) => Poll::Ready(Some(Err(e))),
        }
    };
}
1579
1580
/// Outcome of a single step of a stateful operation.
///
/// This enumeration tells the caller whether the step produced a result that
/// is ready for use (`Ready`) or whether the operation has to be driven
/// further (`Continue`).
pub enum StatefulStreamResult<T> {
    /// The operation is complete and produced a result of type `T`.
    Ready(T),
    /// The operation is not complete yet. The current invocation of the state
    /// did not produce a final result, and the operation should be invoked
    /// again later, with more data and possibly in a different state.
    Continue,
}
1595
1596
768
pub(crate) fn symmetric_join_output_partitioning(
1597
768
    left: &Arc<dyn ExecutionPlan>,
1598
768
    right: &Arc<dyn ExecutionPlan>,
1599
768
    join_type: &JoinType,
1600
768
) -> Partitioning {
1601
768
    let left_columns_len = left.schema().fields.len();
1602
768
    let left_partitioning = left.output_partitioning();
1603
768
    let right_partitioning = right.output_partitioning();
1604
768
    match join_type {
1605
        JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => {
1606
289
            left_partitioning.clone()
1607
        }
1608
166
        JoinType::RightSemi | JoinType::RightAnti => right_partitioning.clone(),
1609
        JoinType::Inner | JoinType::Right => {
1610
208
            adjust_right_output_partitioning(right_partitioning, left_columns_len)
1611
        }
1612
        JoinType::Full => {
1613
            // We could also use left partition count as they are necessarily equal.
1614
105
            Partitioning::UnknownPartitioning(right_partitioning.partition_count())
1615
        }
1616
    }
1617
768
}
1618
1619
379
pub(crate) fn asymmetric_join_output_partitioning(
1620
379
    left: &Arc<dyn ExecutionPlan>,
1621
379
    right: &Arc<dyn ExecutionPlan>,
1622
379
    join_type: &JoinType,
1623
379
) -> Partitioning {
1624
379
    match join_type {
1625
111
        JoinType::Inner | JoinType::Right => adjust_right_output_partitioning(
1626
111
            right.output_partitioning(),
1627
111
            left.schema().fields().len(),
1628
111
        ),
1629
98
        JoinType::RightSemi | JoinType::RightAnti => right.output_partitioning().clone(),
1630
        JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::Full => {
1631
170
            Partitioning::UnknownPartitioning(
1632
170
                right.output_partitioning().partition_count(),
1633
170
            )
1634
        }
1635
    }
1636
379
}
1637
1638
#[cfg(test)]
1639
mod tests {
1640
    use std::pin::Pin;
1641
1642
    use super::*;
1643
1644
    use arrow::datatypes::{DataType, Fields};
1645
    use arrow::error::{ArrowError, Result as ArrowResult};
1646
    use arrow_schema::SortOptions;
1647
1648
    use datafusion_common::stats::Precision::{Absent, Exact, Inexact};
1649
    use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue};
1650
1651
5
    fn check(
1652
5
        left: &[Column],
1653
5
        right: &[Column],
1654
5
        on: &[(PhysicalExprRef, PhysicalExprRef)],
1655
5
    ) -> Result<()> {
1656
5
        let left = left
1657
5
            .iter()
1658
9
            .map(|x| x.to_owned())
1659
5
            .collect::<HashSet<Column>>();
1660
5
        let right = right
1661
5
            .iter()
1662
7
            .map(|x| x.to_owned())
1663
5
            .collect::<HashSet<Column>>();
1664
5
        check_join_set_is_valid(&left, &right, on)
1665
5
    }
1666
1667
    /// A join on a column that exists on both sides is accepted.
    #[test]
    fn check_valid() -> Result<()> {
        let left = [Column::new("a", 0), Column::new("b1", 1)];
        let right = [Column::new("a", 0), Column::new("b2", 1)];
        let on = [(
            Arc::new(Column::new("a", 0)) as _,
            Arc::new(Column::new("a", 0)) as _,
        )];

        check(&left, &right, &on)
    }
1679
1680
    /// The right join key `a` does not exist on the right side: rejected.
    #[test]
    fn check_not_in_right() {
        let left = [Column::new("a", 0), Column::new("b", 1)];
        let right = [Column::new("b", 0)];
        let on = [(
            Arc::new(Column::new("a", 0)) as _,
            Arc::new(Column::new("a", 0)) as _,
        )];

        let outcome = check(&left, &right, &on);
        assert!(outcome.is_err());
    }
1691
1692
    #[tokio::test]
1693
1
    async fn check_error_nesting() {
1694
1
        let once_fut = OnceFut::<()>::new(async {
1695
1
            arrow_err!(ArrowError::CsvError("some error".to_string()))
1696
1
        });
1697
1
1698
1
        struct TestFut(OnceFut<()>);
1699
1
        impl Future for TestFut {
1700
1
            type Output = ArrowResult<()>;
1701
1
1702
1
            fn poll(
1703
1
                mut self: Pin<&mut Self>,
1704
1
                cx: &mut Context<'_>,
1705
1
            ) -> Poll<Self::Output> {
1706
1
                match 
ready!0
(self.0.get(cx)) {
1707
1
                    Ok(()) => 
Poll::Ready(Ok(()))0
,
1708
1
                    Err(e) => Poll::Ready(Err(e.into())),
1709
1
                }
1710
1
            }
1711
1
        }
1712
1
1713
1
        let res = TestFut(once_fut).
await0
;
1714
1
        let arrow_err_from_fut = res.expect_err("once_fut always return error");
1715
1
1716
1
        let wrapped_err = DataFusionError::from(arrow_err_from_fut);
1717
1
        let root_err = wrapped_err.find_root();
1718
1
1719
1
        let _expected =
1720
1
            arrow_datafusion_err!(ArrowError::CsvError("some error".to_owned()));
1721
1
1722
1
        assert!(matches!(root_err, _expected))
1723
1
    }
1724
1725
    /// The left join key `a` does not exist on the left side: rejected.
    #[test]
    fn check_not_in_left() {
        let left = [Column::new("b", 0)];
        let right = [Column::new("a", 0)];
        let on = [(
            Arc::new(Column::new("a", 0)) as _,
            Arc::new(Column::new("a", 0)) as _,
        )];

        let outcome = check(&left, &right, &on);
        assert!(outcome.is_err());
    }
1736
1737
    /// A column name shared by both inputs is fine as long as the join keys
    /// themselves resolve on their respective sides.
    #[test]
    fn check_collision() {
        // column "a" would appear both in left and right
        let left = [Column::new("a", 0), Column::new("c", 1)];
        let right = [Column::new("a", 0), Column::new("b", 1)];
        let on = [(
            Arc::new(Column::new("a", 0)) as _,
            Arc::new(Column::new("b", 1)) as _,
        )];

        let outcome = check(&left, &right, &on);
        assert!(outcome.is_ok());
    }
1749
1750
    /// Join keys with different names on each side are accepted when each key
    /// exists on its own side.
    #[test]
    fn check_in_right() {
        let left = [Column::new("a", 0), Column::new("c", 1)];
        let right = [Column::new("b", 0)];
        let on = [(
            Arc::new(Column::new("a", 0)) as _,
            Arc::new(Column::new("b", 0)) as _,
        )];

        let outcome = check(&left, &right, &on);
        assert!(outcome.is_ok());
    }
1761
1762
    /// Checks the nullability of the output fields produced by
    /// `build_join_schema` for inner, left, right and full joins.
    #[test]
    fn test_join_schema() -> Result<()> {
        let int_schema = |name: &str, nullable: bool| {
            Schema::new(vec![Field::new(name, DataType::Int32, nullable)])
        };
        let a = int_schema("a", false);
        let a_nulls = int_schema("a", true);
        let b = int_schema("b", false);
        let b_nulls = int_schema("b", true);

        // (left input, right input, join type, expected left, expected right)
        let cases = vec![
            (&a, &b, JoinType::Inner, &a, &b),
            (&a, &b_nulls, JoinType::Inner, &a, &b_nulls),
            (&a_nulls, &b, JoinType::Inner, &a_nulls, &b),
            (&a_nulls, &b_nulls, JoinType::Inner, &a_nulls, &b_nulls),
            // right input of a `LEFT` join can be null, regardless of input nullness
            (&a, &b, JoinType::Left, &a, &b_nulls),
            (&a, &b_nulls, JoinType::Left, &a, &b_nulls),
            (&a_nulls, &b, JoinType::Left, &a_nulls, &b_nulls),
            (&a_nulls, &b_nulls, JoinType::Left, &a_nulls, &b_nulls),
            // left input of a `RIGHT` join can be null, regardless of input nullness
            (&a, &b, JoinType::Right, &a_nulls, &b),
            (&a, &b_nulls, JoinType::Right, &a_nulls, &b_nulls),
            (&a_nulls, &b, JoinType::Right, &a_nulls, &b),
            (&a_nulls, &b_nulls, JoinType::Right, &a_nulls, &b_nulls),
            // Either input of a `FULL` join can be null
            (&a, &b, JoinType::Full, &a_nulls, &b_nulls),
            (&a, &b_nulls, JoinType::Full, &a_nulls, &b_nulls),
            (&a_nulls, &b, JoinType::Full, &a_nulls, &b_nulls),
            (&a_nulls, &b_nulls, JoinType::Full, &a_nulls, &b_nulls),
        ];

        for (left_in, right_in, join_type, left_out, right_out) in cases {
            let (schema, _) = build_join_schema(left_in, right_in, &join_type);

            // The expected schema is the expected left fields followed by the
            // expected right fields.
            let expected_fields: Fields = left_out
                .fields()
                .iter()
                .chain(right_out.fields().iter())
                .cloned()
                .collect();
            let expected_schema = Schema::new(expected_fields);

            assert_eq!(
                schema,
                expected_schema,
                "Mismatch with left_in={}:{}, right_in={}:{}, join_type={:?}",
                left_in.fields()[0].name(),
                left_in.fields()[0].is_nullable(),
                right_in.fields()[0].name(),
                right_in.fields()[0].is_nullable(),
                join_type
            );
        }

        Ok(())
    }
1816
1817
58
    fn create_stats(
1818
58
        num_rows: Option<usize>,
1819
58
        column_stats: Vec<ColumnStatistics>,
1820
58
        is_exact: bool,
1821
58
    ) -> Statistics {
1822
58
        Statistics {
1823
58
            num_rows: if is_exact {
1824
8
                num_rows.map(Precision::Exact)
1825
            } else {
1826
50
                num_rows.map(Precision::Inexact)
1827
            }
1828
58
            .unwrap_or(Precision::Absent),
1829
58
            column_statistics: column_stats,
1830
58
            total_byte_size: Precision::Absent,
1831
58
        }
1832
58
    }
1833
1834
83
    fn create_column_stats(
1835
83
        min: Precision<i64>,
1836
83
        max: Precision<i64>,
1837
83
        distinct_count: Precision<usize>,
1838
83
        null_count: Precision<usize>,
1839
83
    ) -> ColumnStatistics {
1840
83
        ColumnStatistics {
1841
83
            distinct_count,
1842
83
            min_value: min.map(ScalarValue::from),
1843
83
            max_value: max.map(ScalarValue::from),
1844
83
            null_count,
1845
83
        }
1846
83
    }
1847
1848
    type PartialStats = (
1849
        usize,
1850
        Precision<i64>,
1851
        Precision<i64>,
1852
        Precision<usize>,
1853
        Precision<usize>,
1854
    );
1855
1856
    // This is mainly for validating all the edge cases of the estimation, but
    // more advanced (and real world) test cases are below, where we need some
    // control over the expected output (since it differs from join type to
    // join type).
    #[test]
    fn test_inner_join_cardinality_single_column() -> Result<()> {
        let cases: Vec<(PartialStats, PartialStats, Option<Precision<usize>>)> = vec![
            // ------------------------------------------------
            // | left(rows, min, max, distinct, null_count),  |
            // | right(rows, min, max, distinct, null_count), |
            // | expected,                                    |
            // ------------------------------------------------

            // Cardinality computation
            // =======================
            //
            // distinct(left) == NaN, distinct(right) == NaN
            (
                (10, Inexact(1), Inexact(10), Absent, Absent),
                (10, Inexact(1), Inexact(10), Absent, Absent),
                Some(Inexact(10)),
            ),
            // range(left) > range(right)
            (
                (10, Inexact(6), Inexact(10), Absent, Absent),
                (10, Inexact(8), Inexact(10), Absent, Absent),
                Some(Inexact(20)),
            ),
            // range(right) > range(left)
            (
                (10, Inexact(8), Inexact(10), Absent, Absent),
                (10, Inexact(6), Inexact(10), Absent, Absent),
                Some(Inexact(20)),
            ),
            // range(left) > len(left), range(right) > len(right)
            (
                (10, Inexact(1), Inexact(15), Absent, Absent),
                (20, Inexact(1), Inexact(40), Absent, Absent),
                Some(Inexact(10)),
            ),
            // When we have distinct count.
            (
                (10, Inexact(1), Inexact(10), Inexact(10), Absent),
                (10, Inexact(1), Inexact(10), Inexact(10), Absent),
                Some(Inexact(10)),
            ),
            // distinct(left) > distinct(right)
            (
                (10, Inexact(1), Inexact(10), Inexact(5), Absent),
                (10, Inexact(1), Inexact(10), Inexact(2), Absent),
                Some(Inexact(20)),
            ),
            // distinct(right) > distinct(left)
            (
                (10, Inexact(1), Inexact(10), Inexact(2), Absent),
                (10, Inexact(1), Inexact(10), Inexact(5), Absent),
                Some(Inexact(20)),
            ),
            // min(left) < 0 (range(left) > range(right))
            (
                (10, Inexact(-5), Inexact(5), Absent, Absent),
                (10, Inexact(1), Inexact(5), Absent, Absent),
                Some(Inexact(10)),
            ),
            // min(right) < 0, max(right) < 0 (range(right) > range(left))
            (
                (10, Inexact(-25), Inexact(-20), Absent, Absent),
                (10, Inexact(-25), Inexact(-15), Absent, Absent),
                Some(Inexact(10)),
            ),
            // range(left) < 0, range(right) >= 0
            // (there isn't a case where both left and right ranges are negative
            //  so one of them is always going to work, this just proves negative
            //  ranges with bigger absolute values are not accidentally used).
            (
                (10, Inexact(-10), Inexact(0), Absent, Absent),
                (10, Inexact(0), Inexact(10), Inexact(5), Absent),
                Some(Inexact(10)),
            ),
            // range(left) = 1, range(right) = 1
            (
                (10, Inexact(1), Inexact(1), Absent, Absent),
                (10, Inexact(1), Inexact(1), Absent, Absent),
                Some(Inexact(100)),
            ),
            //
            // Edge cases
            // ==========
            //
            // No column level stats.
            (
                (10, Absent, Absent, Absent, Absent),
                (10, Absent, Absent, Absent, Absent),
                None,
            ),
            // No min or max (or both).
            (
                (10, Absent, Absent, Inexact(3), Absent),
                (10, Absent, Absent, Inexact(3), Absent),
                None,
            ),
            (
                (10, Inexact(2), Absent, Inexact(3), Absent),
                (10, Absent, Inexact(5), Inexact(3), Absent),
                None,
            ),
            (
                (10, Absent, Inexact(3), Inexact(3), Absent),
                (10, Inexact(1), Absent, Inexact(3), Absent),
                None,
            ),
            (
                (10, Absent, Inexact(3), Absent, Absent),
                (10, Inexact(1), Absent, Absent, Absent),
                None,
            ),
            // Non overlapping min/max (when exact=False).
            (
                (10, Absent, Inexact(4), Absent, Absent),
                (10, Inexact(5), Absent, Absent, Absent),
                Some(Inexact(0)),
            ),
            (
                (10, Inexact(0), Inexact(10), Absent, Absent),
                (10, Inexact(11), Inexact(20), Absent, Absent),
                Some(Inexact(0)),
            ),
            (
                (10, Inexact(11), Inexact(20), Absent, Absent),
                (10, Inexact(0), Inexact(10), Absent, Absent),
                Some(Inexact(0)),
            ),
            // distinct(left) = 0, distinct(right) = 0
            (
                (10, Inexact(1), Inexact(10), Inexact(0), Absent),
                (10, Inexact(1), Inexact(10), Inexact(0), Absent),
                None,
            ),
            // Inexact row count < exact null count with absent distinct count
            (
                (0, Inexact(1), Inexact(10), Absent, Exact(5)),
                (10, Inexact(1), Inexact(10), Absent, Absent),
                Some(Inexact(0)),
            ),
        ];

        for (left_info, right_info, expected_cardinality) in cases {
            let (left_num_rows, left_min, left_max, left_distinct, left_nulls) = left_info;
            let left_col_stats = vec![create_column_stats(
                left_min,
                left_max,
                left_distinct,
                left_nulls,
            )];

            let (right_num_rows, right_min, right_max, right_distinct, right_nulls) =
                right_info;
            let right_col_stats = vec![create_column_stats(
                right_min,
                right_max,
                right_distinct,
                right_nulls,
            )];

            // Direct estimation from the two input statistics.
            let left_stats = Statistics {
                num_rows: Inexact(left_num_rows),
                total_byte_size: Absent,
                column_statistics: left_col_stats.clone(),
            };
            let right_stats = Statistics {
                num_rows: Inexact(right_num_rows),
                total_byte_size: Absent,
                column_statistics: right_col_stats.clone(),
            };
            assert_eq!(
                estimate_inner_join_cardinality(left_stats, right_stats),
                expected_cardinality.clone()
            );

            // We should also be able to use join_cardinality to get the same results
            let join_type = JoinType::Inner;
            let join_on = vec![(
                Arc::new(Column::new("a", 0)) as _,
                Arc::new(Column::new("b", 0)) as _,
            )];
            let partial_join_stats = estimate_join_cardinality(
                &join_type,
                create_stats(Some(left_num_rows), left_col_stats.clone(), false),
                create_stats(Some(right_num_rows), right_col_stats.clone(), false),
                &join_on,
            );

            assert_eq!(
                partial_join_stats.clone().map(|s| Inexact(s.num_rows)),
                expected_cardinality.clone()
            );
            // The column statistics of both inputs are passed through as-is.
            assert_eq!(
                partial_join_stats.map(|s| s.column_statistics),
                expected_cardinality.map(|_| [left_col_stats, right_col_stats].concat())
            );
        }
        Ok(())
    }
2058
2059
    /// With several columns per input, the estimate is driven by the highest
    /// distinct count found among the column statistics.
    #[test]
    fn test_inner_join_cardinality_multiple_column() -> Result<()> {
        let left_col_stats = vec![
            create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
            create_column_stats(Inexact(100), Inexact(500), Inexact(150), Absent),
        ];

        let right_col_stats = vec![
            create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
            create_column_stats(Inexact(100), Inexact(500), Inexact(200), Absent),
        ];

        // Both inputs have an inexact row count of 400 and no byte size.
        let left_stats = create_stats(Some(400), left_col_stats, false);
        let right_stats = create_stats(Some(400), right_col_stats, false);

        // We have statistics about 4 columns, where the highest distinct
        // count is 200, so we are going to pick it.
        let expected = Some(Precision::Inexact((400 * 400) / 200));
        assert_eq!(
            estimate_inner_join_cardinality(left_stats, right_stats),
            expected
        );
        Ok(())
    }
2090
2091
    /// Cardinality estimation also works when min/max are `Decimal128` values.
    #[test]
    fn test_inner_join_cardinality_decimal_range() -> Result<()> {
        // Column statistics with only an inexact Decimal128(14, 4) min/max.
        let decimal_range_stats = |min: i128, max: i128| ColumnStatistics {
            distinct_count: Precision::Absent,
            min_value: Precision::Inexact(ScalarValue::Decimal128(Some(min), 14, 4)),
            max_value: Precision::Inexact(ScalarValue::Decimal128(Some(max), 14, 4)),
            ..Default::default()
        };
        let inexact_100_rows = |column_statistics: Vec<ColumnStatistics>| Statistics {
            num_rows: Precision::Inexact(100),
            total_byte_size: Precision::Absent,
            column_statistics,
        };

        let left_col_stats = vec![decimal_range_stats(32500, 35000)];
        let right_col_stats = vec![decimal_range_stats(33500, 34000)];

        assert_eq!(
            estimate_inner_join_cardinality(
                inexact_100_rows(left_col_stats),
                inexact_100_rows(right_col_stats),
            ),
            Some(Precision::Inexact(100))
        );
        Ok(())
    }
2124
2125
    /// Checks the estimated row count of each supported join type for a
    /// two-column equi-join.
    #[test]
    fn test_join_cardinality() -> Result<()> {
        // Left table (rows=1000)
        //   a: min=0, max=100, distinct=100
        //   b: min=0, max=500, distinct=500
        //   x: min=1000, max=10000, distinct=None
        //
        // Right table (rows=2000)
        //   c: min=0, max=100, distinct=50
        //   d: min=0, max=2000, distinct=2500 (how? some inexact statistics)
        //   y: min=0, max=100, distinct=None
        //
        // Join on a=c, b=d (ignore x/y)
        let left_col_stats = vec![
            create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
            create_column_stats(Inexact(0), Inexact(500), Inexact(500), Absent),
            create_column_stats(Inexact(1000), Inexact(10000), Absent, Absent),
        ];

        let right_col_stats = vec![
            create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
            create_column_stats(Inexact(0), Inexact(2000), Inexact(2500), Absent),
            create_column_stats(Inexact(0), Inexact(100), Absent, Absent),
        ];

        // The output column statistics are the inputs' statistics, left first.
        let expected_col_stats =
            [left_col_stats.clone(), right_col_stats.clone()].concat();

        // (join type, expected number of rows)
        let cases = vec![
            (JoinType::Inner, 800),
            (JoinType::Left, 1000),
            (JoinType::Right, 2000),
            (JoinType::Full, 2200),
        ];

        for (join_type, expected_num_rows) in cases {
            let join_on = vec![
                (
                    Arc::new(Column::new("a", 0)) as _,
                    Arc::new(Column::new("c", 0)) as _,
                ),
                (
                    Arc::new(Column::new("b", 1)) as _,
                    Arc::new(Column::new("d", 1)) as _,
                ),
            ];

            let left_stats = create_stats(Some(1000), left_col_stats.clone(), false);
            let right_stats = create_stats(Some(2000), right_col_stats.clone(), false);
            let partial_join_stats =
                estimate_join_cardinality(&join_type, left_stats, right_stats, &join_on)
                    .unwrap();

            assert_eq!(partial_join_stats.num_rows, expected_num_rows);
            assert_eq!(partial_join_stats.column_statistics, expected_col_stats);
        }

        Ok(())
    }
2185
2186
    /// Checks `estimate_join_cardinality` when one of the two join key pairs
    /// (`x = y`) has value ranges that cannot overlap: the inner join is
    /// expected to estimate 0 rows, the left/right joins the row count of
    /// their respective side, and the full join the sum of both sides. The
    /// returned column statistics must still be the left input's followed by
    /// the right input's.
    #[test]
2187
1
    fn test_join_cardinality_when_one_column_is_disjoint() -> Result<()> {
2188
1
        // Left table (rows=1000)
2189
1
        //   a: min=0, max=100, distinct=100
2190
1
        //   b: min=0, max=500, distinct=500
2191
1
        //   x: min=1000, max=10000, distinct=None
2192
1
        //
2193
1
        // Right table (rows=2000)
2194
1
        //   c: min=0, max=100, distinct=50
2195
1
        //   d: min=0, max=2000, distinct=2500 (exceeds rows=2000; tolerated because the statistics are inexact)
2196
1
        //   y: min=0, max=100, distinct=None
2197
1
        //
2198
1
        // Join on a=c, x=y (ignores b/d), where the ranges of x and y do not intersect
2199
1
2200
1
        let left_col_stats = vec![
2201
1
            create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
2202
1
            create_column_stats(Inexact(0), Inexact(500), Inexact(500), Absent),
2203
1
            create_column_stats(Inexact(1000), Inexact(10000), Absent, Absent),
2204
1
        ];
2205
1
2206
1
        let right_col_stats = vec![
2207
1
            create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
2208
1
            create_column_stats(Inexact(0), Inexact(2000), Inexact(2500), Absent),
2209
1
            create_column_stats(Inexact(0), Inexact(100), Absent, Absent),
2210
1
        ];
2211
1
2212
1
        let join_on = vec![
2213
1
            (
2214
1
                Arc::new(Column::new("a", 0)) as _,
2215
1
                Arc::new(Column::new("c", 0)) as _,
2216
1
            ),
2217
1
            (
2218
1
                Arc::new(Column::new("x", 2)) as _,
2219
1
                Arc::new(Column::new("y", 2)) as _,
2220
1
            ),
2221
1
        ];
2222
1
2223
1
        let cases = vec![
2224
1
            // Join type, expected cardinality
2225
1
            //
2226
1
            // When an inner join is disjoint, that means it won't
2227
1
            // produce any rows.
2228
1
            (JoinType::Inner, 0),
2229
1
            // But left/right outer joins will produce at least
2230
1
            // the amount of rows from the left/right side.
2231
1
            (JoinType::Left, 1000),
2232
1
            (JoinType::Right, 2000),
2233
1
            // And a full outer join will produce at least the combination
2234
1
            // of the rows above (minus the cardinality of the inner join, which
2235
1
            // is 0).
2236
1
            (JoinType::Full, 3000),
2237
1
        ];
2238
2239
5
        for (
join_type, expected_num_rows4
) in cases {
2240
4
            // NOTE(review): the trailing `true` passed to `create_stats` presumably
            // marks the statistics as exact -- confirm against the helper.
            let partial_join_stats = estimate_join_cardinality(
2241
4
                &join_type,
2242
4
                create_stats(Some(1000), left_col_stats.clone(), true),
2243
4
                create_stats(Some(2000), right_col_stats.clone(), true),
2244
4
                &join_on,
2245
4
            )
2246
4
            .unwrap();
2247
4
            assert_eq!(partial_join_stats.num_rows, expected_num_rows);
2248
4
            assert_eq!(
2249
4
                partial_join_stats.column_statistics,
2250
4
                [left_col_stats.clone(), right_col_stats.clone()].concat()
2251
4
            );
2252
        }
2253
2254
1
        Ok(())
2255
1
    }
2256
2257
    /// Table-driven check of `estimate_join_cardinality` for semi and anti
    /// joins on a single key pair (`l_col = r_col`).
    ///
    /// Each case lists the join type, the left and right inputs as
    /// `(rows, min, max, distinct, null_count)`, and the expected `num_rows`
    /// of the estimate. The cases cover overlapping key ranges, inputs with
    /// no min/max statistics, and key ranges that cannot overlap; for the
    /// latter, `LeftSemi` expects 0 rows while `LeftAnti` expects the left
    /// row count.
    #[test]
2258
1
    fn test_anti_semi_join_cardinality() -> Result<()> {
2259
1
        let cases: Vec<(JoinType, PartialStats, PartialStats, Option<usize>)> = vec![
2260
1
            // ------------------------------------------------
2261
1
            // | join_type ,                                   |
2262
1
            // | left(rows, min, max, distinct, null_count), |
2263
1
            // | right(rows, min, max, distinct, null_count), |
2264
1
            // | expected,                                    |
2265
1
            // ------------------------------------------------
2266
1
2267
1
            // Cardinality computation
2268
1
            // =======================
2269
1
            (
2270
1
                JoinType::LeftSemi,
2271
1
                (50, Inexact(10), Inexact(20), Absent, Absent),
2272
1
                (10, Inexact(15), Inexact(25), Absent, Absent),
2273
1
                Some(50),
2274
1
            ),
2275
1
            (
2276
1
                JoinType::RightSemi,
2277
1
                (50, Inexact(10), Inexact(20), Absent, Absent),
2278
1
                (10, Inexact(15), Inexact(25), Absent, Absent),
2279
1
                Some(10),
2280
1
            ),
2281
1
            (
2282
1
                JoinType::LeftSemi,
2283
1
                (10, Absent, Absent, Absent, Absent),
2284
1
                (50, Absent, Absent, Absent, Absent),
2285
1
                Some(10),
2286
1
            ),
2287
1
            (
2288
1
                JoinType::LeftSemi,
2289
1
                (50, Inexact(10), Inexact(20), Absent, Absent),
2290
1
                (10, Inexact(30), Inexact(40), Absent, Absent),
2291
1
                Some(0),
2292
1
            ),
2293
1
            (
2294
1
                JoinType::LeftSemi,
2295
1
                (50, Inexact(10), Absent, Absent, Absent),
2296
1
                (10, Absent, Inexact(5), Absent, Absent),
2297
1
                Some(0),
2298
1
            ),
2299
1
            (
2300
1
                JoinType::LeftSemi,
2301
1
                (50, Absent, Inexact(20), Absent, Absent),
2302
1
                (10, Inexact(30), Absent, Absent, Absent),
2303
1
                Some(0),
2304
1
            ),
2305
1
            (
2306
1
                JoinType::LeftAnti,
2307
1
                (50, Inexact(10), Inexact(20), Absent, Absent),
2308
1
                (10, Inexact(15), Inexact(25), Absent, Absent),
2309
1
                Some(50),
2310
1
            ),
2311
1
            (
2312
1
                JoinType::RightAnti,
2313
1
                (50, Inexact(10), Inexact(20), Absent, Absent),
2314
1
                (10, Inexact(15), Inexact(25), Absent, Absent),
2315
1
                Some(10),
2316
1
            ),
2317
1
            (
2318
1
                JoinType::LeftAnti,
2319
1
                (10, Absent, Absent, Absent, Absent),
2320
1
                (50, Absent, Absent, Absent, Absent),
2321
1
                Some(10),
2322
1
            ),
2323
1
            (
2324
1
                JoinType::LeftAnti,
2325
1
                (50, Inexact(10), Inexact(20), Absent, Absent),
2326
1
                (10, Inexact(30), Inexact(40), Absent, Absent),
2327
1
                Some(50),
2328
1
            ),
2329
1
            (
2330
1
                JoinType::LeftAnti,
2331
1
                (50, Inexact(10), Absent, Absent, Absent),
2332
1
                (10, Absent, Inexact(5), Absent, Absent),
2333
1
                Some(50),
2334
1
            ),
2335
1
            (
2336
1
                JoinType::LeftAnti,
2337
1
                (50, Absent, Inexact(20), Absent, Absent),
2338
1
                (10, Inexact(30), Absent, Absent, Absent),
2339
1
                Some(50),
2340
1
            ),
2341
1
        ];
2342
1
2343
1
        let join_on = vec![(
2344
1
            Arc::new(Column::new("l_col", 0)) as _,
2345
1
            Arc::new(Column::new("r_col", 0)) as _,
2346
1
        )];
2347
2348
13
        // `outer_info` is always passed as the left input and `inner_info` as the
        // right input below, whatever the join type of the case is.
        for (
join_type, outer_info, inner_info, expected12
) in cases {
2349
12
            let outer_num_rows = outer_info.0;
2350
12
            let outer_col_stats = vec![create_column_stats(
2351
12
                outer_info.1,
2352
12
                outer_info.2,
2353
12
                outer_info.3,
2354
12
                outer_info.4,
2355
12
            )];
2356
12
2357
12
            let inner_num_rows = inner_info.0;
2358
12
            let inner_col_stats = vec![create_column_stats(
2359
12
                inner_info.1,
2360
12
                inner_info.2,
2361
12
                inner_info.3,
2362
12
                inner_info.4,
2363
12
            )];
2364
12
2365
12
            let output_cardinality = estimate_join_cardinality(
2366
12
                &join_type,
2367
12
                Statistics {
2368
12
                    num_rows: Inexact(outer_num_rows),
2369
12
                    total_byte_size: Absent,
2370
12
                    column_statistics: outer_col_stats,
2371
12
                },
2372
12
                Statistics {
2373
12
                    num_rows: Inexact(inner_num_rows),
2374
12
                    total_byte_size: Absent,
2375
12
                    column_statistics: inner_col_stats,
2376
12
                },
2377
12
                &join_on,
2378
12
            )
2379
12
            // Only the estimated row count is compared against `expected`.
            .map(|cardinality| cardinality.num_rows);
2380
12
2381
12
            assert_eq!(
2382
                output_cardinality, expected,
2383
0
                "failure for join_type: {}",
2384
                join_type
2385
            );
2386
        }
2387
2388
1
        Ok(())
2389
1
    }
2390
2391
    /// Checks how `estimate_join_cardinality` handles a `LeftSemi` join when
    /// `num_rows` is `Absent` on one or both inputs (all column statistics are
    /// `Absent` throughout):
    /// - left absent, right `Exact(10)`: no estimate (`None`);
    /// - left `Inexact(500)`, right absent: an estimate with `num_rows == 500`;
    /// - both absent: no estimate (`None`).
    #[test]
2392
1
    fn test_semi_join_cardinality_absent_rows() -> Result<()> {
2393
1
        let dummy_column_stats =
2394
1
            vec![create_column_stats(Absent, Absent, Absent, Absent)];
2395
1
        let join_on = vec![(
2396
1
            Arc::new(Column::new("l_col", 0)) as _,
2397
1
            Arc::new(Column::new("r_col", 0)) as _,
2398
1
        )];
2399
1
2400
1
        // Scenario 1: row count missing on the left (outer) input only.
        let absent_outer_estimation = estimate_join_cardinality(
2401
1
            &JoinType::LeftSemi,
2402
1
            Statistics {
2403
1
                num_rows: Absent,
2404
1
                total_byte_size: Absent,
2405
1
                column_statistics: dummy_column_stats.clone(),
2406
1
            },
2407
1
            Statistics {
2408
1
                num_rows: Exact(10),
2409
1
                total_byte_size: Absent,
2410
1
                column_statistics: dummy_column_stats.clone(),
2411
1
            },
2412
1
            &join_on,
2413
1
        );
2414
1
        assert!(
2415
1
            absent_outer_estimation.is_none(),
2416
0
            "Expected \"None\" estimated SemiJoin cardinality for absent outer num_rows"
2417
        );
2418
2419
1
        // Scenario 2: row count missing on the right (inner) input only.
        let absent_inner_estimation = estimate_join_cardinality(
2420
1
            &JoinType::LeftSemi,
2421
1
            Statistics {
2422
1
                num_rows: Inexact(500),
2423
1
                total_byte_size: Absent,
2424
1
                column_statistics: dummy_column_stats.clone(),
2425
1
            },
2426
1
            Statistics {
2427
1
                num_rows: Absent,
2428
1
                total_byte_size: Absent,
2429
1
                column_statistics: dummy_column_stats.clone(),
2430
1
            },
2431
1
            &join_on,
2432
1
        ).expect("Expected non-empty PartialJoinStatistics for SemiJoin with absent inner num_rows");
2433
1
2434
1
        assert_eq!(absent_inner_estimation.num_rows, 500, 
"Expected outer.num_rows estimated SemiJoin cardinality for absent inner num_rows"0
);
2435
2436
1
        // Scenario 3: row count missing on both inputs. This binding shadows the
        // `absent_inner_estimation` from scenario 2.
        let absent_inner_estimation = estimate_join_cardinality(
2437
1
            &JoinType::LeftSemi,
2438
1
            Statistics {
2439
1
                num_rows: Absent,
2440
1
                total_byte_size: Absent,
2441
1
                column_statistics: dummy_column_stats.clone(),
2442
1
            },
2443
1
            Statistics {
2444
1
                num_rows: Absent,
2445
1
                total_byte_size: Absent,
2446
1
                column_statistics: dummy_column_stats,
2447
1
            },
2448
1
            &join_on,
2449
1
        );
2450
1
        assert!(absent_inner_estimation.is_none(), 
"Expected \"None\" estimated SemiJoin cardinality for absent outer and inner num_rows"0
);
2451
2452
1
        Ok(())
2453
1
    }
2454
2455
    /// Checks `calculate_join_output_ordering` for an inner join on `b = x`
    /// with `left_columns_len` set to 5, once per
    /// `(maintains_input_order, probe_side)` pair.
    ///
    /// For `[true, false]` with the left probe side, the expected ordering is
    /// the left ordering followed by the right ordering. For `[false, true]`
    /// with the right probe side, it is the right ordering followed by the
    /// left ordering. In both expectations the right-side column indices are
    /// shifted by `left_columns_len` (`z`: 2 -> 7, `y`: 1 -> 6).
    #[test]
2456
1
    fn test_calculate_join_output_ordering() -> Result<()> {
2457
1
        let options = SortOptions::default();
2458
1
        let left_ordering = vec![
2459
1
            PhysicalSortExpr {
2460
1
                expr: Arc::new(Column::new("a", 0)),
2461
1
                options,
2462
1
            },
2463
1
            PhysicalSortExpr {
2464
1
                expr: Arc::new(Column::new("c", 2)),
2465
1
                options,
2466
1
            },
2467
1
            PhysicalSortExpr {
2468
1
                expr: Arc::new(Column::new("d", 3)),
2469
1
                options,
2470
1
            },
2471
1
        ];
2472
1
        let right_ordering = vec![
2473
1
            PhysicalSortExpr {
2474
1
                expr: Arc::new(Column::new("z", 2)),
2475
1
                options,
2476
1
            },
2477
1
            PhysicalSortExpr {
2478
1
                expr: Arc::new(Column::new("y", 1)),
2479
1
                options,
2480
1
            },
2481
1
        ];
2482
1
        let join_type = JoinType::Inner;
2483
1
        let on_columns = [(
2484
1
            Arc::new(Column::new("b", 1)) as _,
2485
1
            Arc::new(Column::new("x", 0)) as _,
2486
1
        )];
2487
1
        let left_columns_len = 5;
2488
1
        // The i-th entry of each array below is paired with `expected[i]`.
        let maintains_input_orders = [[true, false], [false, true]];
2489
1
        let probe_sides = [Some(JoinSide::Left), Some(JoinSide::Right)];
2490
1
2491
1
        let expected = [
2492
1
            // Case 0: left ordering first, then the right ordering with shifted indices.
            Some(vec![
2493
1
                PhysicalSortExpr {
2494
1
                    expr: Arc::new(Column::new("a", 0)),
2495
1
                    options,
2496
1
                },
2497
1
                PhysicalSortExpr {
2498
1
                    expr: Arc::new(Column::new("c", 2)),
2499
1
                    options,
2500
1
                },
2501
1
                PhysicalSortExpr {
2502
1
                    expr: Arc::new(Column::new("d", 3)),
2503
1
                    options,
2504
1
                },
2505
1
                PhysicalSortExpr {
2506
1
                    expr: Arc::new(Column::new("z", 7)),
2507
1
                    options,
2508
1
                },
2509
1
                PhysicalSortExpr {
2510
1
                    expr: Arc::new(Column::new("y", 6)),
2511
1
                    options,
2512
1
                },
2513
1
            ]),
2514
1
            // Case 1: right ordering (shifted indices) first, then the left ordering.
            Some(vec![
2515
1
                PhysicalSortExpr {
2516
1
                    expr: Arc::new(Column::new("z", 7)),
2517
1
                    options,
2518
1
                },
2519
1
                PhysicalSortExpr {
2520
1
                    expr: Arc::new(Column::new("y", 6)),
2521
1
                    options,
2522
1
                },
2523
1
                PhysicalSortExpr {
2524
1
                    expr: Arc::new(Column::new("a", 0)),
2525
1
                    options,
2526
1
                },
2527
1
                PhysicalSortExpr {
2528
1
                    expr: Arc::new(Column::new("c", 2)),
2529
1
                    options,
2530
1
                },
2531
1
                PhysicalSortExpr {
2532
1
                    expr: Arc::new(Column::new("d", 3)),
2533
1
                    options,
2534
1
                },
2535
1
            ]),
2536
1
        ];
2537
2538
2
        for (i, (maintains_input_order, probe_side)) in
2539
1
            maintains_input_orders.iter().zip(probe_sides).enumerate()
2540
        {
2541
2
            assert_eq!(
2542
2
                calculate_join_output_ordering(
2543
2
                    &left_ordering,
2544
2
                    &right_ordering,
2545
2
                    join_type,
2546
2
                    &on_columns,
2547
2
                    left_columns_len,
2548
2
                    maintains_input_order,
2549
2
                    probe_side,
2550
2
                ),
2551
2
                expected[i]
2552
2
            );
2553
        }
2554
2555
1
        Ok(())
2556
1
    }
2557
}