/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/joins/hash_join.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! [`HashJoinExec`] Partitioned Hash Join Operator |
19 | | |
20 | | use std::fmt; |
21 | | use std::sync::atomic::{AtomicUsize, Ordering}; |
22 | | use std::sync::Arc; |
23 | | use std::task::Poll; |
24 | | use std::{any::Any, vec}; |
25 | | |
26 | | use super::utils::asymmetric_join_output_partitioning; |
27 | | use super::{ |
28 | | utils::{OnceAsync, OnceFut}, |
29 | | PartitionMode, |
30 | | }; |
31 | | use crate::ExecutionPlanProperties; |
32 | | use crate::{ |
33 | | coalesce_partitions::CoalescePartitionsExec, |
34 | | common::can_project, |
35 | | execution_mode_from_children, handle_state, |
36 | | hash_utils::create_hashes, |
37 | | joins::utils::{ |
38 | | adjust_indices_by_join_type, apply_join_filter_to_indices, |
39 | | build_batch_from_indices, build_join_schema, check_join_is_valid, |
40 | | estimate_join_statistics, get_final_indices_from_bit_map, |
41 | | need_produce_result_in_final, symmetric_join_output_partitioning, |
42 | | BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinHashMap, JoinHashMapOffset, |
43 | | JoinHashMapType, JoinOn, JoinOnRef, StatefulStreamResult, |
44 | | }, |
45 | | metrics::{ExecutionPlanMetricsSet, MetricsSet}, |
46 | | DisplayAs, DisplayFormatType, Distribution, ExecutionMode, ExecutionPlan, |
47 | | Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream, |
48 | | Statistics, |
49 | | }; |
50 | | |
51 | | use arrow::array::{ |
52 | | Array, ArrayRef, BooleanArray, BooleanBufferBuilder, UInt32Array, UInt64Array, |
53 | | }; |
54 | | use arrow::compute::kernels::cmp::{eq, not_distinct}; |
55 | | use arrow::compute::{and, concat_batches, take, FilterBuilder}; |
56 | | use arrow::datatypes::{Schema, SchemaRef}; |
57 | | use arrow::record_batch::RecordBatch; |
58 | | use arrow::util::bit_util; |
59 | | use arrow_array::cast::downcast_array; |
60 | | use arrow_schema::ArrowError; |
61 | | use datafusion_common::utils::memory::estimate_memory_size; |
62 | | use datafusion_common::{ |
63 | | internal_datafusion_err, internal_err, plan_err, project_schema, DataFusionError, |
64 | | JoinSide, JoinType, Result, |
65 | | }; |
66 | | use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; |
67 | | use datafusion_execution::TaskContext; |
68 | | use datafusion_physical_expr::equivalence::{ |
69 | | join_equivalence_properties, ProjectionMapping, |
70 | | }; |
71 | | use datafusion_physical_expr::PhysicalExprRef; |
72 | | |
73 | | use ahash::RandomState; |
74 | | use datafusion_expr::Operator; |
75 | | use datafusion_physical_expr_common::datum::compare_op_for_nested; |
76 | | use futures::{ready, Stream, StreamExt, TryStreamExt}; |
77 | | use parking_lot::Mutex; |
78 | | |
79 | | type SharedBitmapBuilder = Mutex<BooleanBufferBuilder>; |
80 | | |
81 | | /// HashTable and input data for the left (build side) of a join |
82 | | struct JoinLeftData { |
83 | | /// The hash table with indices into `batch` |
84 | | hash_map: JoinHashMap, |
85 | | /// The input rows for the build side |
86 | | batch: RecordBatch, |
87 | | /// Shared bitmap builder for visited left indices |
88 | | visited_indices_bitmap: Mutex<BooleanBufferBuilder>, |
89 | | /// Counter of running probe-threads, potentially |
90 | | /// able to update `visited_indices_bitmap` |
91 | | probe_threads_counter: AtomicUsize, |
92 | | /// Memory reservation that tracks memory used by the `hash_map` hash |
93 | | /// table and `batch`. Cleared on drop. |
94 | | #[allow(dead_code)] |
95 | | reservation: MemoryReservation, |
96 | | } |
97 | | |
98 | | impl JoinLeftData { |
99 | | /// Create a new `JoinLeftData` from its parts |
100 | 1.71k | fn new( |
101 | 1.71k | hash_map: JoinHashMap, |
102 | 1.71k | batch: RecordBatch, |
103 | 1.71k | visited_indices_bitmap: SharedBitmapBuilder, |
104 | 1.71k | probe_threads_counter: AtomicUsize, |
105 | 1.71k | reservation: MemoryReservation, |
106 | 1.71k | ) -> Self { |
107 | 1.71k | Self { |
108 | 1.71k | hash_map, |
109 | 1.71k | batch, |
110 | 1.71k | visited_indices_bitmap, |
111 | 1.71k | probe_threads_counter, |
112 | 1.71k | reservation, |
113 | 1.71k | } |
114 | 1.71k | } |
115 | | |
116 | | /// returns a reference to the hash map |
117 | 5.16k | fn hash_map(&self) -> &JoinHashMap { |
118 | 5.16k | &self.hash_map |
119 | 5.16k | } |
120 | | |
121 | | /// returns a reference to the build side batch |
122 | 15.3k | fn batch(&self) -> &RecordBatch { |
123 | 15.3k | &self.batch |
124 | 15.3k | } |
125 | | |
126 | | /// returns a reference to the visited indices bitmap |
127 | 3.45k | fn visited_indices_bitmap(&self) -> &SharedBitmapBuilder { |
128 | 3.45k | &self.visited_indices_bitmap |
129 | 3.45k | } |
130 | | |
131 | | /// Decrements the counter of running threads, and returns `true` |
132 | | /// if the caller is the last running thread |
133 | 862 | fn report_probe_completed(&self) -> bool { |
134 | 862 | self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1 |
135 | 862 | } |
136 | | } |
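
The `report_probe_completed` check works because `fetch_sub` returns the value *before* the decrement, so exactly one caller observes `1`. A minimal standalone sketch of the same pattern (illustrative names, not code from this file):

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;

fn main() {
    // One "slot" per probe thread, as collect_left_input sets up via
    // `AtomicUsize::new(probe_threads_count)`.
    let counter = Arc::new(AtomicUsize::new(4));

    let handles: Vec<_> = (0..4)
        .map(|_| {
            let counter = Arc::clone(&counter);
            // `fetch_sub` returns the previous value, so exactly one
            // thread sees `1` and is elected to do the final work.
            thread::spawn(move || counter.fetch_sub(1, Ordering::Relaxed) == 1)
        })
        .collect();

    let num_last = handles
        .into_iter()
        .map(|h| h.join().unwrap())
        .filter(|is_last| *is_last)
        .count();
    assert_eq!(num_last, 1); // exactly one thread is "last"
}
```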
137 | | |
138 | | /// Join execution plan: Evaluates equijoin predicates in parallel on multiple |
139 | | /// partitions using a hash table and an optional filter list to apply post |
140 | | /// join. |
141 | | /// |
142 | | /// # Join Expressions |
143 | | /// |
144 | | /// This implementation is optimized for evaluating equijoin predicates |
145 | | /// (`<col1> = <col2>`), which are represented as a list of `Columns` |
146 | | /// in [`Self::on`]. |
147 | | /// |
148 | | /// Non-equality predicates, which cannot be pushed down to the join inputs |
149 | | /// (e.g. `<col1> != <col2>`), are known as "filter expressions" and are |
150 | | /// evaluated after the equijoin predicates. |
151 | | /// |
152 | | /// # "Build Side" vs "Probe Side" |
153 | | /// |
154 | | /// HashJoin takes two inputs, which are referred to as the "build" and the |
155 | | /// "probe". The build side is the first child, and the probe side is the second |
156 | | /// child. |
157 | | /// |
158 | | /// The two inputs are treated differently and it is VERY important that the |
159 | | /// *smaller* input is placed on the build side to minimize the work of creating |
160 | | /// the hash table. |
161 | | /// |
162 | | /// ```text |
163 | | /// ┌───────────┐ |
164 | | /// │ HashJoin │ |
165 | | /// │ │ |
166 | | /// └───────────┘ |
167 | | /// │ │ |
168 | | /// ┌─────┘ └─────┐ |
169 | | /// ▼ ▼ |
170 | | /// ┌────────────┐ ┌─────────────┐ |
171 | | /// │ Input │ │ Input │ |
172 | | /// │ [0] │ │ [1] │ |
173 | | /// └────────────┘ └─────────────┘ |
174 | | /// |
175 | | /// "build side" "probe side" |
176 | | /// ``` |
177 | | /// |
178 | | /// Execution proceeds in 2 stages: |
179 | | /// |
180 | | /// 1. the **build phase** creates a hash table from the tuples of the build side, |
181 | | /// and a single concatenated batch containing data from all fetched record batches. |
182 | | /// The resulting hash table stores the hashed join-key fields of each row as keys, |
183 | | /// and indices of the corresponding rows in the concatenated batch as values. |
184 | | /// |
185 | | /// The hash join uses a LIFO data structure as a hash table. In order to retain |
186 | | /// the original build-side input order while fetching data during the probe phase, |
187 | | /// the hash table is updated by iterating over the batch sequence in reverse order |
188 | | /// -- this keeps rows with smaller indices "on top" of the hash table, while still |
189 | | /// maintaining correct indexing into the concatenated build-side batch. |
190 | | /// |
191 | | /// Example of build phase for 3 record batches: |
192 | | /// |
193 | | /// |
194 | | /// ```text |
195 | | /// |
196 | | /// Original build-side data Inserting build-side values into hashmap Concatenated build-side batch |
197 | | /// ┌───────────────────────────┐ |
198 | | ///                                   hashmap.insert(row-hash, row-idx + offset)        │ idx │ |
199 | | /// ┌───────┐ │ ┌───────┐ │ |
200 | | /// │ Row 1 │ 1) update_hash for batch 3 with offset 0 │ │ Row 6 │ 0 │ |
201 | | /// Batch 1 │ │ - hashmap.insert(Row 7, idx 1) │ Batch 3 │ │ │ |
202 | | /// │ Row 2 │ - hashmap.insert(Row 6, idx 0) │ │ Row 7 │ 1 │ |
203 | | /// └───────┘ │ └───────┘ │ |
204 | | /// │ │ |
205 | | /// ┌───────┐ │ ┌───────┐ │ |
206 | | /// │ Row 3 │ 2) update_hash for batch 2 with offset 2 │ │ Row 3 │ 2 │ |
207 | | /// │ │ - hashmap.insert(Row 5, idx 4) │ │ │ │ |
208 | | /// Batch 2 │ Row 4 │ - hashmap.insert(Row 4, idx 3) │ Batch 2 │ Row 4 │ 3 │ |
209 | | /// │ │ - hashmap.insert(Row 3, idx 2) │ │ │ │ |
210 | | /// │ Row 5 │ │ │ Row 5 │ 4 │ |
211 | | /// └───────┘ │ └───────┘ │ |
212 | | /// │ │ |
213 | | /// ┌───────┐ │ ┌───────┐ │ |
214 | | /// │ Row 6 │ 3) update_hash for batch 1 with offset 5 │ │ Row 1 │ 5 │ |
215 | | /// Batch 3 │ │ - hashmap.insert(Row 2, idx 5) │ Batch 1 │ │ │ |
216 | | /// │ Row 7 │ - hashmap.insert(Row 1, idx 6) │ │ Row 2 │ 6 │ |
217 | | /// └───────┘ │ └───────┘ │ |
218 | | /// │ │ |
219 | | /// └───────────────────────────┘ |
220 | | /// |
221 | | /// ``` |
222 | | /// |
223 | | /// 2. the **probe phase** where the tuples of the probe side are streamed |
224 | | /// through, checking for matches of the join keys in the hash table. |
225 | | /// |
226 | | /// ```text |
227 | | /// ┌────────────────┐ ┌────────────────┐ |
228 | | /// │ ┌─────────┐ │ │ ┌─────────┐ │ |
229 | | /// │ │ Hash │ │ │ │ Hash │ │ |
230 | | /// │ │ Table │ │ │ │ Table │ │ |
231 | | /// │ │(keys are│ │ │ │(keys are│ │ |
232 | | /// │ │equi join│ │ │ │equi join│ │ Stage 2: batches from |
233 | | /// Stage 1: the │ │columns) │ │ │ │columns) │ │ the probe side are |
234 | | /// *entire* build │ │ │ │ │ │ │ │ streamed through, and |
235 | | /// side is read │ └─────────┘ │ │ └─────────┘ │ checked against the |
236 | | /// into the hash │ ▲ │ │ ▲ │ contents of the hash |
237 | | /// table │ HashJoin │ │ HashJoin │ table |
238 | | /// └──────┼─────────┘ └──────────┼─────┘ |
239 | | /// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ |
240 | | /// │ │ |
241 | | /// |
242 | | /// │ │ |
243 | | /// ┌────────────┐ ┌────────────┐ |
244 | | /// │RecordBatch │ │RecordBatch │ |
245 | | /// └────────────┘ └────────────┘ |
246 | | /// ┌────────────┐ ┌────────────┐ |
247 | | /// │RecordBatch │ │RecordBatch │ |
248 | | /// └────────────┘ └────────────┘ |
249 | | /// ... ... |
250 | | /// ┌────────────┐ ┌────────────┐ |
251 | | /// │RecordBatch │ │RecordBatch │ |
252 | | /// └────────────┘ └────────────┘ |
253 | | /// |
254 | | /// build side probe side |
255 | | /// |
256 | | /// ``` |
257 | | /// |
258 | | /// # Example "Optimal" Plans |
259 | | /// |
260 | | /// The difference in the inputs means that for a classic "Star Schema Query", |
261 | | /// the optimal plan will be a **"Right Deep Tree"**. A Star Schema Query is |
262 | | /// one where there is one large table and several smaller "dimension" tables, |
263 | | /// joined on `Foreign Key = Primary Key` predicates. |
264 | | /// |
265 | | /// A "Right Deep Tree" is one where the large table is the probe side of the |
266 | | /// lowest join: |
267 | | /// |
268 | | /// ```text |
269 | | /// ┌───────────┐ |
270 | | /// │ HashJoin │ |
271 | | /// │ │ |
272 | | /// └───────────┘ |
273 | | /// │ │ |
274 | | /// ┌───────┘ └──────────┐ |
275 | | /// ▼ ▼ |
276 | | /// ┌───────────────┐ ┌───────────┐ |
277 | | /// │ small table 1 │ │ HashJoin │ |
278 | | /// │ "dimension" │ │ │ |
279 | | /// └───────────────┘ └───┬───┬───┘ |
280 | | /// ┌──────────┘ └───────┐ |
281 | | /// │ │ |
282 | | /// ▼ ▼ |
283 | | /// ┌───────────────┐ ┌───────────┐ |
284 | | /// │ small table 2 │ │ HashJoin │ |
285 | | /// │ "dimension" │ │ │ |
286 | | /// └───────────────┘ └───┬───┬───┘ |
287 | | /// ┌────────┘ └────────┐ |
288 | | /// │ │ |
289 | | /// ▼ ▼ |
290 | | /// ┌───────────────┐ ┌───────────────┐ |
291 | | /// │ small table 3 │ │ large table │ |
292 | | /// │ "dimension" │ │ "fact" │ |
293 | | /// └───────────────┘ └───────────────┘ |
294 | | /// ``` |
295 | | #[derive(Debug)] |
296 | | pub struct HashJoinExec { |
297 | | /// left (build) side which gets hashed |
298 | | pub left: Arc<dyn ExecutionPlan>, |
299 | | /// right (probe) side which is filtered by the hash table |
300 | | pub right: Arc<dyn ExecutionPlan>, |
301 | | /// Set of equijoin columns from the relations: `(left_col, right_col)` |
302 | | pub on: Vec<(PhysicalExprRef, PhysicalExprRef)>, |
303 | | /// Filters which are applied while finding matching rows |
304 | | pub filter: Option<JoinFilter>, |
305 | | /// How the join is performed (`OUTER`, `INNER`, etc) |
306 | | pub join_type: JoinType, |
307 | | /// The schema after the join is applied. Be careful when using this schema: |
308 | | /// if there is a projection, it is not the same as the output schema. |
309 | | join_schema: SchemaRef, |
310 | | /// Future that consumes left input and builds the hash table |
311 | | left_fut: OnceAsync<JoinLeftData>, |
312 | | /// Shares the `RandomState` for the hashing algorithm |
313 | | random_state: RandomState, |
314 | | /// Partitioning mode to use |
315 | | pub mode: PartitionMode, |
316 | | /// Execution metrics |
317 | | metrics: ExecutionPlanMetricsSet, |
318 | | /// The projection indices of the columns in the output schema of the join |
319 | | pub projection: Option<Vec<usize>>, |
320 | | /// Information of index and left / right placement of columns |
321 | | column_indices: Vec<ColumnIndex>, |
322 | | /// Null matching behavior: If `null_equals_null` is true, rows that have |
323 | | /// `null`s in both left and right equijoin columns will be matched. |
324 | | /// Otherwise, rows that have `null`s in the join columns will not be |
325 | | /// matched and thus will not appear in the output. |
326 | | pub null_equals_null: bool, |
327 | | /// Cache holding plan properties like equivalences, output partitioning etc. |
328 | | cache: PlanProperties, |
329 | | } |
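
The offset bookkeeping from the build-phase diagram above can be sketched in isolation. This is a toy model only: a plain `HashMap` from key to row index stands in for the real `JoinHashMap`, and the keys are unique so chain handling is elided:

```rust
use std::collections::HashMap;

fn main() {
    // Three "batches" of build-side join keys, as in the diagram above.
    let batches: Vec<Vec<&str>> = vec![
        vec!["Row 1", "Row 2"],          // Batch 1
        vec!["Row 3", "Row 4", "Row 5"], // Batch 2
        vec!["Row 6", "Row 7"],          // Batch 3
    ];

    // Iterate the batches in *reverse* while assigning offsets, so the
    // concatenated order is Batch 3, Batch 2, Batch 1 -- mirroring how
    // collect_left_input pairs `update_hash` with `batches.iter().rev()`.
    let mut index_of: HashMap<&str, usize> = HashMap::new();
    let mut offset = 0;
    for batch in batches.iter().rev() {
        for (i, key) in batch.iter().enumerate() {
            index_of.insert(*key, offset + i);
        }
        offset += batch.len();
    }

    // Indices match the "Concatenated build-side batch" column above.
    assert_eq!(index_of["Row 6"], 0);
    assert_eq!(index_of["Row 3"], 2);
    assert_eq!(index_of["Row 1"], 5);
}
```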
330 | | |
331 | | impl HashJoinExec { |
332 | | /// Tries to create a new [HashJoinExec]. |
333 | | /// |
334 | | /// # Error |
335 | | /// This function errors when it is not possible to join the left and right sides on keys `on`. |
336 | | #[allow(clippy::too_many_arguments)] |
337 | 683 | pub fn try_new( |
338 | 683 | left: Arc<dyn ExecutionPlan>, |
339 | 683 | right: Arc<dyn ExecutionPlan>, |
340 | 683 | on: JoinOn, |
341 | 683 | filter: Option<JoinFilter>, |
342 | 683 | join_type: &JoinType, |
343 | 683 | projection: Option<Vec<usize>>, |
344 | 683 | partition_mode: PartitionMode, |
345 | 683 | null_equals_null: bool, |
346 | 683 | ) -> Result<Self> { |
347 | 683 | let left_schema = left.schema(); |
348 | 683 | let right_schema = right.schema(); |
349 | 683 | if on.is_empty() { |
350 | 0 | return plan_err!("On constraints in HashJoinExec should be non-empty"); |
351 | 683 | } |
352 | 683 | |
353 | 683 | check_join_is_valid(&left_schema, &right_schema, &on)?0 ; |
354 | | |
355 | 683 | let (join_schema, column_indices) = |
356 | 683 | build_join_schema(&left_schema, &right_schema, join_type); |
357 | 683 | |
358 | 683 | let random_state = RandomState::with_seeds(0, 0, 0, 0); |
359 | 683 | |
360 | 683 | let join_schema = Arc::new(join_schema); |
361 | 683 | |
362 | 683 | // check if the projection is valid |
363 | 683 | can_project(&join_schema, projection.as_ref())?0 ; |
364 | | |
365 | 683 | let cache = Self::compute_properties( |
366 | 683 | &left, |
367 | 683 | &right, |
368 | 683 | Arc::clone(&join_schema), |
369 | 683 | *join_type, |
370 | 683 | &on, |
371 | 683 | partition_mode, |
372 | 683 | projection.as_ref(), |
373 | 683 | )?0 ; |
374 | | |
375 | 683 | Ok(HashJoinExec { |
376 | 683 | left, |
377 | 683 | right, |
378 | 683 | on, |
379 | 683 | filter, |
380 | 683 | join_type: *join_type, |
381 | 683 | join_schema, |
382 | 683 | left_fut: Default::default(), |
383 | 683 | random_state, |
384 | 683 | mode: partition_mode, |
385 | 683 | metrics: ExecutionPlanMetricsSet::new(), |
386 | 683 | projection, |
387 | 683 | column_indices, |
388 | 683 | null_equals_null, |
389 | 683 | cache, |
390 | 683 | }) |
391 | 683 | } |
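
For orientation, a hedged sketch of constructing this operator outside this file; it assumes the `MemoryExec` in-memory source and the `Column` physical expression from sibling modules/crates, and joins a table to itself on a single column:

```rust
use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::{JoinType, Result};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalExprRef;
use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode};
use datafusion_physical_plan::memory::MemoryExec;

fn build_join() -> Result<HashJoinExec> {
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;

    // Each side is a single-partition in-memory scan.
    let left = Arc::new(MemoryExec::try_new(
        &[vec![batch.clone()]],
        Arc::clone(&schema),
        None,
    )?);
    let right = Arc::new(MemoryExec::try_new(&[vec![batch]], Arc::clone(&schema), None)?);

    // One equijoin pair: left.id = right.id (column index 0 on both sides).
    let on: Vec<(PhysicalExprRef, PhysicalExprRef)> = vec![(
        Arc::new(Column::new("id", 0)),
        Arc::new(Column::new("id", 0)),
    )];

    HashJoinExec::try_new(
        left,
        right,
        on,
        None, // no residual join filter
        &JoinType::Inner,
        None, // no output projection
        PartitionMode::CollectLeft,
        false, // null_equals_null
    )
}
```

Executing such a plan then follows the build/probe flow described above; with `CollectLeft`, a single hash table is built once and shared by all probe partitions.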
392 | | |
393 | | /// left (build) side which gets hashed |
394 | 0 | pub fn left(&self) -> &Arc<dyn ExecutionPlan> { |
395 | 0 | &self.left |
396 | 0 | } |
397 | | |
398 | | /// right (probe) side which is filtered by the hash table |
399 | 327 | pub fn right(&self) -> &Arc<dyn ExecutionPlan> { |
400 | 327 | &self.right |
401 | 327 | } |
402 | | |
403 | | /// Set of common columns used to join on |
404 | 0 | pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] { |
405 | 0 | &self.on |
406 | 0 | } |
407 | | |
408 | | /// Filters applied before join output |
409 | 0 | pub fn filter(&self) -> Option<&JoinFilter> { |
410 | 0 | self.filter.as_ref() |
411 | 0 | } |
412 | | |
413 | | /// How the join is performed |
414 | 0 | pub fn join_type(&self) -> &JoinType { |
415 | 0 | &self.join_type |
416 | 0 | } |
417 | | |
418 | | /// The partitioning mode of this hash join |
419 | 0 | pub fn partition_mode(&self) -> &PartitionMode { |
420 | 0 | &self.mode |
421 | 0 | } |
422 | | |
423 | | /// Get null_equals_null |
424 | 0 | pub fn null_equals_null(&self) -> bool { |
425 | 0 | self.null_equals_null |
426 | 0 | } |
427 | | |
428 | | /// Calculate order preservation flags for this hash join. |
429 | 683 | fn maintains_input_order(join_type: JoinType) -> Vec<bool> { |
430 | 683 | vec![ |
431 | | false, |
432 | 340 | matches!( |
433 | 683 | join_type, |
434 | | JoinType::Inner |
435 | | | JoinType::Right |
436 | | | JoinType::RightAnti |
437 | | | JoinType::RightSemi |
438 | | ), |
439 | | ] |
440 | 683 | } |
441 | | |
442 | | /// Get probe side information for the hash join. |
443 | 683 | pub fn probe_side() -> JoinSide { |
444 | 683 | // In the current implementation, the right side is always the probe side. |
445 | 683 | JoinSide::Right |
446 | 683 | } |
447 | | |
448 | | /// Return whether the join contains a projection |
449 | 0 | pub fn contain_projection(&self) -> bool { |
450 | 0 | self.projection.is_some() |
451 | 0 | } |
452 | | |
453 | | /// Return new instance of [HashJoinExec] with the given projection. |
454 | 0 | pub fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> { |
455 | 0 | // check if the projection is valid |
456 | 0 | can_project(&self.schema(), projection.as_ref())?; |
457 | 0 | let projection = match projection { |
458 | 0 | Some(projection) => match &self.projection { |
459 | 0 | Some(p) => Some(projection.iter().map(|i| p[*i]).collect()), |
460 | 0 | None => Some(projection), |
461 | | }, |
462 | 0 | None => None, |
463 | | }; |
464 | 0 | Self::try_new( |
465 | 0 | Arc::clone(&self.left), |
466 | 0 | Arc::clone(&self.right), |
467 | 0 | self.on.clone(), |
468 | 0 | self.filter.clone(), |
469 | 0 | &self.join_type, |
470 | 0 | projection, |
471 | 0 | self.mode, |
472 | 0 | self.null_equals_null, |
473 | 0 | ) |
474 | 0 | } |
475 | | |
476 | | /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. |
477 | 683 | fn compute_properties( |
478 | 683 | left: &Arc<dyn ExecutionPlan>, |
479 | 683 | right: &Arc<dyn ExecutionPlan>, |
480 | 683 | schema: SchemaRef, |
481 | 683 | join_type: JoinType, |
482 | 683 | on: JoinOnRef, |
483 | 683 | mode: PartitionMode, |
484 | 683 | projection: Option<&Vec<usize>>, |
485 | 683 | ) -> Result<PlanProperties> { |
486 | 683 | // Calculate equivalence properties: |
487 | 683 | let mut eq_properties = join_equivalence_properties( |
488 | 683 | left.equivalence_properties().clone(), |
489 | 683 | right.equivalence_properties().clone(), |
490 | 683 | &join_type, |
491 | 683 | Arc::clone(&schema), |
492 | 683 | &Self::maintains_input_order(join_type), |
493 | 683 | Some(Self::probe_side()), |
494 | 683 | on, |
495 | 683 | ); |
496 | | |
497 | 683 | let mut output_partitioning = match mode { |
498 | | PartitionMode::CollectLeft => { |
499 | 327 | asymmetric_join_output_partitioning(left, right, &join_type) |
500 | | } |
501 | 0 | PartitionMode::Auto => Partitioning::UnknownPartitioning( |
502 | 0 | right.output_partitioning().partition_count(), |
503 | 0 | ), |
504 | | PartitionMode::Partitioned => { |
505 | 356 | symmetric_join_output_partitioning(left, right, &join_type) |
506 | | } |
507 | | }; |
508 | | |
509 | | // Determine execution mode by checking whether this join is pipeline |
510 | | // breaking. This happens when the left side is unbounded, or the right |
511 | | // side is unbounded with `Left`, `Full`, `LeftAnti` or `LeftSemi` join types. |
512 | 683 | let pipeline_breaking = left.execution_mode().is_unbounded() |
513 | 683 | || (right.execution_mode().is_unbounded() |
514 | 0 | && matches!( |
515 | 0 | join_type, |
516 | | JoinType::Left |
517 | | | JoinType::Full |
518 | | | JoinType::LeftAnti |
519 | | | JoinType::LeftSemi |
520 | | )); |
521 | | |
522 | 683 | let mode = if pipeline_breaking { |
523 | 0 | ExecutionMode::PipelineBreaking |
524 | | } else { |
525 | 683 | execution_mode_from_children([left, right]) |
526 | | }; |
527 | | |
528 | | // If contains projection, update the PlanProperties. |
529 | 683 | if let Some(projection0 ) = projection { |
530 | | // construct a map from the input expressions to the output expression of the Projection |
531 | 0 | let projection_mapping = |
532 | 0 | ProjectionMapping::from_indices(projection, &schema)?; |
533 | 0 | let out_schema = project_schema(&schema, Some(projection))?; |
534 | 0 | output_partitioning = |
535 | 0 | output_partitioning.project(&projection_mapping, &eq_properties); |
536 | 0 | eq_properties = eq_properties.project(&projection_mapping, out_schema); |
537 | 683 | } |
538 | 683 | Ok(PlanProperties::new( |
539 | 683 | eq_properties, |
540 | 683 | output_partitioning, |
541 | 683 | mode, |
542 | 683 | )) |
543 | 683 | } |
544 | | } |
545 | | |
546 | | impl DisplayAs for HashJoinExec { |
547 | 0 | fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { |
548 | 0 | match t { |
549 | | DisplayFormatType::Default | DisplayFormatType::Verbose => { |
550 | 0 | let display_filter = self.filter.as_ref().map_or_else( |
551 | 0 | || "".to_string(), |
552 | 0 | |f| format!(", filter={}", f.expression()), |
553 | 0 | ); |
554 | 0 | let display_projections = if self.contain_projection() { |
555 | 0 | format!( |
556 | 0 | ", projection=[{}]", |
557 | 0 | self.projection |
558 | 0 | .as_ref() |
559 | 0 | .unwrap() |
560 | 0 | .iter() |
561 | 0 | .map(|index| format!( |
562 | 0 | "{}@{}", |
563 | 0 | self.join_schema.fields().get(*index).unwrap().name(), |
564 | 0 | index |
565 | 0 | )) |
566 | 0 | .collect::<Vec<_>>() |
567 | 0 | .join(", ") |
568 | 0 | ) |
569 | | } else { |
570 | 0 | "".to_string() |
571 | | }; |
572 | 0 | let on = self |
573 | 0 | .on |
574 | 0 | .iter() |
575 | 0 | .map(|(c1, c2)| format!("({}, {})", c1, c2)) |
576 | 0 | .collect::<Vec<String>>() |
577 | 0 | .join(", "); |
578 | 0 | write!( |
579 | 0 | f, |
580 | 0 | "HashJoinExec: mode={:?}, join_type={:?}, on=[{}]{}{}", |
581 | 0 | self.mode, self.join_type, on, display_filter, display_projections |
582 | 0 | ) |
583 | 0 | } |
584 | 0 | } |
585 | 0 | } |
586 | | } |
587 | | |
588 | | impl ExecutionPlan for HashJoinExec { |
589 | 0 | fn name(&self) -> &'static str { |
590 | 0 | "HashJoinExec" |
591 | 0 | } |
592 | | |
593 | 0 | fn as_any(&self) -> &dyn Any { |
594 | 0 | self |
595 | 0 | } |
596 | | |
597 | 1.91k | fn properties(&self) -> &PlanProperties { |
598 | 1.91k | &self.cache |
599 | 1.91k | } |
600 | | |
601 | 0 | fn required_input_distribution(&self) -> Vec<Distribution> { |
602 | 0 | match self.mode { |
603 | 0 | PartitionMode::CollectLeft => vec![ |
604 | 0 | Distribution::SinglePartition, |
605 | 0 | Distribution::UnspecifiedDistribution, |
606 | 0 | ], |
607 | | PartitionMode::Partitioned => { |
608 | 0 | let (left_expr, right_expr) = self |
609 | 0 | .on |
610 | 0 | .iter() |
611 | 0 | .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) |
612 | 0 | .unzip(); |
613 | 0 | vec![ |
614 | 0 | Distribution::HashPartitioned(left_expr), |
615 | 0 | Distribution::HashPartitioned(right_expr), |
616 | 0 | ] |
617 | | } |
618 | 0 | PartitionMode::Auto => vec![ |
619 | 0 | Distribution::UnspecifiedDistribution, |
620 | 0 | Distribution::UnspecifiedDistribution, |
621 | 0 | ], |
622 | | } |
623 | 0 | } |
624 | | |
625 | | // For [JoinType::Inner] and [JoinType::RightSemi] hash joins, the probe phase begins by |
626 | | // applying the hash function to the join key(s) of each probe-side row, in the order the |
627 | | // rows are arranged. The hash value is used to look up corresponding entries in the hash |
628 | | // table that was constructed from the build side table during the build phase. |
629 | | // |
630 | | // Because of the immediate generation of result rows once a match is found, |
631 | | // the output of the join tends to follow the order in which the rows were read from |
632 | | // the probe side table. This is simply due to the sequence in which the rows were processed. |
633 | | // Hence, it appears that the hash join is preserving the order of the probe side. |
634 | | // |
635 | | // Meanwhile, in the case of a [JoinType::RightAnti] hash join, |
636 | | // the unmatched rows from the probe side are also kept in order. |
637 | | // This is because the **`RightAnti`** join is designed to return rows from the right |
638 | | // (probe side) table that have no match in the left (build side) table. Because the rows |
639 | | // are processed sequentially in the probe phase, and unmatched rows are directly output |
640 | | // as results, these results tend to retain the order of the probe side table. |
641 | 0 | fn maintains_input_order(&self) -> Vec<bool> { |
642 | 0 | Self::maintains_input_order(self.join_type) |
643 | 0 | } |
644 | | |
645 | 0 | fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { |
646 | 0 | vec![&self.left, &self.right] |
647 | 0 | } |
648 | | |
649 | 0 | fn with_new_children( |
650 | 0 | self: Arc<Self>, |
651 | 0 | children: Vec<Arc<dyn ExecutionPlan>>, |
652 | 0 | ) -> Result<Arc<dyn ExecutionPlan>> { |
653 | 0 | Ok(Arc::new(HashJoinExec::try_new( |
654 | 0 | Arc::clone(&children[0]), |
655 | 0 | Arc::clone(&children[1]), |
656 | 0 | self.on.clone(), |
657 | 0 | self.filter.clone(), |
658 | 0 | &self.join_type, |
659 | 0 | self.projection.clone(), |
660 | 0 | self.mode, |
661 | 0 | self.null_equals_null, |
662 | 0 | )?)) |
663 | 0 | } |
664 | | |
665 | 1.75k | fn execute( |
666 | 1.75k | &self, |
667 | 1.75k | partition: usize, |
668 | 1.75k | context: Arc<TaskContext>, |
669 | 1.75k | ) -> Result<SendableRecordBatchStream> { |
670 | 1.75k | let on_left = self |
671 | 1.75k | .on |
672 | 1.75k | .iter() |
673 | 1.76k | .map(|on| Arc::clone(&on.0)) |
674 | 1.75k | .collect::<Vec<_>>(); |
675 | 1.75k | let on_right = self |
676 | 1.75k | .on |
677 | 1.75k | .iter() |
678 | 1.76k | .map(|on| Arc::clone(&on.1)) |
679 | 1.75k | .collect::<Vec<_>>(); |
680 | 1.75k | let left_partitions = self.left.output_partitioning().partition_count(); |
681 | 1.75k | let right_partitions = self.right.output_partitioning().partition_count(); |
682 | 1.75k | |
683 | 1.75k | if self.mode == PartitionMode::Partitioned && left_partitions != right_partitions1.40k |
684 | | { |
685 | 0 | return internal_err!( |
686 | 0 | "Invalid HashJoinExec, partition count mismatch {left_partitions}!={right_partitions},\ |
687 | 0 | consider using RepartitionExec" |
688 | 0 | ); |
689 | 1.75k | } |
690 | 1.75k | |
691 | 1.75k | let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics); |
692 | 1.75k | let left_fut = match self.mode { |
693 | 356 | PartitionMode::CollectLeft => self.left_fut.once(|| { |
694 | 327 | let reservation = |
695 | 327 | MemoryConsumer::new("HashJoinInput").register(context.memory_pool()); |
696 | 327 | collect_left_input( |
697 | 327 | None, |
698 | 327 | self.random_state.clone(), |
699 | 327 | Arc::clone(&self.left), |
700 | 327 | on_left.clone(), |
701 | 327 | Arc::clone(&context), |
702 | 327 | join_metrics.clone(), |
703 | 327 | reservation, |
704 | 327 | need_produce_result_in_final(self.join_type), |
705 | 327 | self.right().output_partitioning().partition_count(), |
706 | 327 | ) |
707 | 356 | }), |
708 | | PartitionMode::Partitioned => { |
709 | 1.40k | let reservation = |
710 | 1.40k | MemoryConsumer::new(format!("HashJoinInput[{partition}]")) |
711 | 1.40k | .register(context.memory_pool()); |
712 | 1.40k | |
713 | 1.40k | OnceFut::new(collect_left_input( |
714 | 1.40k | Some(partition), |
715 | 1.40k | self.random_state.clone(), |
716 | 1.40k | Arc::clone(&self.left), |
717 | 1.40k | on_left.clone(), |
718 | 1.40k | Arc::clone(&context), |
719 | 1.40k | join_metrics.clone(), |
720 | 1.40k | reservation, |
721 | 1.40k | need_produce_result_in_final(self.join_type), |
722 | 1.40k | 1, |
723 | 1.40k | )) |
724 | | } |
725 | | PartitionMode::Auto => { |
726 | 0 | return plan_err!( |
727 | 0 | "Invalid HashJoinExec, unsupported PartitionMode {:?} in execute()", |
728 | 0 | PartitionMode::Auto |
729 | 0 | ); |
730 | | } |
731 | | }; |
732 | | |
733 | 1.75k | let batch_size = context.session_config().batch_size(); |
734 | | |
735 | | // we have the batches and the hash map with their keys. We can now create a stream |
736 | | // over the right that uses this information to issue new batches. |
737 | 1.75k | let right_stream = self.right.execute(partition, context)?0 ; |
738 | | |
739 | | // update column indices to reflect the projection |
740 | 1.75k | let column_indices_after_projection = match &self.projection { |
741 | 0 | Some(projection) => projection |
742 | 0 | .iter() |
743 | 0 | .map(|i| self.column_indices[*i].clone()) |
744 | 0 | .collect(), |
745 | 1.75k | None => self.column_indices.clone(), |
746 | | }; |
747 | | |
748 | 1.75k | Ok(Box::pin(HashJoinStream { |
749 | 1.75k | schema: self.schema(), |
750 | 1.75k | on_left, |
751 | 1.75k | on_right, |
752 | 1.75k | filter: self.filter.clone(), |
753 | 1.75k | join_type: self.join_type, |
754 | 1.75k | right: right_stream, |
755 | 1.75k | column_indices: column_indices_after_projection, |
756 | 1.75k | random_state: self.random_state.clone(), |
757 | 1.75k | join_metrics, |
758 | 1.75k | null_equals_null: self.null_equals_null, |
759 | 1.75k | state: HashJoinStreamState::WaitBuildSide, |
760 | 1.75k | build_side: BuildSide::Initial(BuildSideInitialState { left_fut }), |
761 | 1.75k | batch_size, |
762 | 1.75k | hashes_buffer: vec![], |
763 | 1.75k | right_side_ordered: self.right.output_ordering().is_some(), |
764 | 1.75k | })) |
765 | 1.75k | } |
766 | | |
767 | 0 | fn metrics(&self) -> Option<MetricsSet> { |
768 | 0 | Some(self.metrics.clone_inner()) |
769 | 0 | } |
770 | | |
771 | 0 | fn statistics(&self) -> Result<Statistics> { |
772 | | // TODO stats: it is not possible in general to know the output size of joins |
773 | | // There are some special cases though, for example: |
774 | | // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` |
775 | 0 | let mut stats = estimate_join_statistics( |
776 | 0 | Arc::clone(&self.left), |
777 | 0 | Arc::clone(&self.right), |
778 | 0 | self.on.clone(), |
779 | 0 | &self.join_type, |
780 | 0 | &self.join_schema, |
781 | 0 | )?; |
782 | | // Project statistics if there is a projection |
783 | 0 | if let Some(projection) = &self.projection { |
784 | 0 | stats.column_statistics = stats |
785 | 0 | .column_statistics |
786 | 0 | .into_iter() |
787 | 0 | .enumerate() |
788 | 0 | .filter(|(i, _)| projection.contains(i)) |
789 | 0 | .map(|(_, s)| s) |
790 | 0 | .collect(); |
791 | 0 | } |
792 | 0 | Ok(stats) |
793 | 0 | } |
794 | | } |
795 | | |
796 | | /// Reads the left (build) side of the input, buffering it in memory, to build a |
797 | | /// hash table (`JoinLeftData`) |
798 | | #[allow(clippy::too_many_arguments)] |
799 | 1.72k | async fn collect_left_input( |
800 | 1.72k | partition: Option<usize>, |
801 | 1.72k | random_state: RandomState, |
802 | 1.72k | left: Arc<dyn ExecutionPlan>, |
803 | 1.72k | on_left: Vec<PhysicalExprRef>, |
804 | 1.72k | context: Arc<TaskContext>, |
805 | 1.72k | metrics: BuildProbeJoinMetrics, |
806 | 1.72k | reservation: MemoryReservation, |
807 | 1.72k | with_visited_indices_bitmap: bool, |
808 | 1.72k | probe_threads_count: usize, |
809 | 1.72k | ) -> Result<JoinLeftData> { |
810 | 1.72k | let schema = left.schema(); |
811 | | |
812 | 1.72k | let (left_input, left_input_partition) = if let Some(partition1.40k ) = partition { |
813 | 1.40k | (left, partition) |
814 | 327 | } else if left.output_partitioning().partition_count() != 1 { |
815 | 6 | (Arc::new(CoalescePartitionsExec::new(left)) as _, 0) |
816 | | } else { |
817 | 321 | (left, 0) |
818 | | }; |
819 | | |
820 | | // Depending on the partition argument, load a single partition or the whole left side into memory |
821 | 1.72k | let stream = left_input.execute(left_input_partition, Arc::clone(&context))?0 ; |
822 | | |
823 | | // This operation performs 2 steps at once: |
824 | | // 1. creates a [JoinHashMap] of all batches from the stream |
825 | | // 2. stores the batches in a vector. |
826 | 1.72k | let initial = (Vec::new(), 0, metrics, reservation); |
827 | 1.72k | let (batches, num_rows, metrics, mut reservation1.71k ) = stream |
828 | 4.31k | .try_fold(initial, |mut acc, batch| async { |
829 | 4.31k | let batch_size = batch.get_array_memory_size(); |
830 | 4.31k | // Reserve memory for incoming batch |
831 | 4.31k | acc.3.try_grow(batch_size)?16 ; |
832 | | // Update metrics |
833 | 4.30k | acc.2.build_mem_used.add(batch_size); |
834 | 4.30k | acc.2.build_input_batches.add(1); |
835 | 4.30k | acc.2.build_input_rows.add(batch.num_rows()); |
836 | 4.30k | // Update rowcount |
837 | 4.30k | acc.1 += batch.num_rows(); |
838 | 4.30k | // Push batch to output |
839 | 4.30k | acc.0.push(batch); |
840 | 4.30k | Ok(acc) |
841 | 8.63k | })1.72k |
842 | 1.49k | .await?16 ; |
843 | | |
844 | | // Estimate the memory size required for the hash table prior to allocation. |
845 | | // The final result can be verified using `RawTable::allocation_info()`. |
846 | 1.71k | let fixed_size = std::mem::size_of::<JoinHashMap>(); |
847 | 1.71k | let estimated_hashtable_size = |
848 | 1.71k | estimate_memory_size::<(u64, u64)>(num_rows, fixed_size)?0 ; |
849 | | |
850 | 1.71k | reservation.try_grow(estimated_hashtable_size)?0 ; |
851 | 1.71k | metrics.build_mem_used.add(estimated_hashtable_size); |
852 | 1.71k | |
853 | 1.71k | let mut hashmap = JoinHashMap::with_capacity(num_rows); |
854 | 1.71k | let mut hashes_buffer = Vec::new(); |
855 | 1.71k | let mut offset = 0; |
856 | 1.71k | |
857 | 1.71k | // Updating hashmap starting from the last batch |
858 | 1.71k | let batches_iter = batches.iter().rev(); |
859 | 4.30k | for batch in batches_iter.clone()1.71k { |
860 | 4.30k | hashes_buffer.clear(); |
861 | 4.30k | hashes_buffer.resize(batch.num_rows(), 0); |
862 | 4.30k | update_hash( |
863 | 4.30k | &on_left, |
864 | 4.30k | batch, |
865 | 4.30k | &mut hashmap, |
866 | 4.30k | offset, |
867 | 4.30k | &random_state, |
868 | 4.30k | &mut hashes_buffer, |
869 | 4.30k | 0, |
870 | 4.30k | true, |
871 | 4.30k | )?0 ; |
872 | 4.30k | offset += batch.num_rows(); |
873 | | } |
874 | | // Merge all batches into a single batch, so we can directly index into the arrays |
875 | 1.71k | let single_batch = concat_batches(&schema, batches_iter)?0 ; |
876 | | |
877 | | // Reserve additional memory for visited indices bitmap and create shared builder |
878 | 1.71k | let visited_indices_bitmap = if with_visited_indices_bitmap { |
879 | 854 | let bitmap_size = bit_util::ceil(single_batch.num_rows(), 8); |
880 | 854 | reservation.try_grow(bitmap_size)?0 ; |
881 | 854 | metrics.build_mem_used.add(bitmap_size); |
882 | 854 | |
883 | 854 | let mut bitmap_buffer = BooleanBufferBuilder::new(single_batch.num_rows()); |
884 | 854 | bitmap_buffer.append_n(num_rows, false); |
885 | 854 | bitmap_buffer |
886 | | } else { |
887 | 857 | BooleanBufferBuilder::new(0) |
888 | | }; |
889 | | |
890 | 1.71k | let data = JoinLeftData::new( |
891 | 1.71k | hashmap, |
892 | 1.71k | single_batch, |
893 | 1.71k | Mutex::new(visited_indices_bitmap), |
894 | 1.71k | AtomicUsize::new(probe_threads_count), |
895 | 1.71k | reservation, |
896 | 1.71k | ); |
897 | 1.71k | |
898 | 1.71k | Ok(data) |
899 | 1.72k | } |
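
The point of merging into `single_batch` is that indices returned by hash-table lookups can be fed directly to arrow's `take` kernel. A small standalone illustration with toy data (the `indices` here stand in for a hash-table lookup result):

```rust
use std::sync::Arc;

use arrow::array::{AsArray, Int32Array, UInt64Array};
use arrow::compute::{concat_batches, take};
use arrow::datatypes::{DataType, Field, Int32Type, Schema};
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;

fn main() -> Result<(), ArrowError> {
    let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, false)]));
    let b1 = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![1, 2]))],
    )?;
    let b2 = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![3, 4]))],
    )?;

    // Merge buffered batches into one, as collect_left_input does.
    let single = concat_batches(&schema, [&b1, &b2])?;

    // Build-side indices from a (hypothetical) hash-table lookup now point
    // straight into the concatenated arrays.
    let indices = UInt64Array::from(vec![3, 0]);
    let gathered = take(single.column(0).as_ref(), &indices, None)?;
    assert_eq!(gathered.as_primitive::<Int32Type>().values().to_vec(), vec![4, 1]);
    Ok(())
}
```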
900 | | |
901 | | /// Updates `hash_map` with new entries from `batch` evaluated against the expressions `on` |
902 | | /// using `offset` as a start value for `batch` row indices. |
903 | | /// |
904 | | /// `fifo_hashmap` sets the order of iteration over `batch` rows while updating the hash map, |
905 | | /// keeping either the first (if set to `true`) or the last (if set to `false`) row index |
906 | | /// as the chain head for rows with equal hash values. |
907 | | #[allow(clippy::too_many_arguments)] |
908 | 12.4k | pub fn update_hash<T>( |
909 | 12.4k | on: &[PhysicalExprRef], |
910 | 12.4k | batch: &RecordBatch, |
911 | 12.4k | hash_map: &mut T, |
912 | 12.4k | offset: usize, |
913 | 12.4k | random_state: &RandomState, |
914 | 12.4k | hashes_buffer: &mut Vec<u64>, |
915 | 12.4k | deleted_offset: usize, |
916 | 12.4k | fifo_hashmap: bool, |
917 | 12.4k | ) -> Result<()> |
918 | 12.4k | where |
919 | 12.4k | T: JoinHashMapType, |
920 | 12.4k | { |
921 | | // evaluate the keys |
922 | 12.4k | let keys_values = on |
923 | 12.4k | .iter() |
924 | 12.4k | .map(|c| c.evaluate(batch)?0 .into_array(batch.num_rows())) |
925 | 12.4k | .collect::<Result<Vec<_>>>()?0 ; |
926 | | |
927 | | // calculate the hash values |
928 | 12.4k | let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?0 ; |
929 | | |
930 | 12.4k | // For the usual JoinHashMap, this is a no-op. |
931 | 12.4k | hash_map.extend_zero(batch.num_rows()); |
932 | 12.4k | |
933 | 12.4k | // Updating JoinHashMap from hash values iterator |
934 | 12.4k | let hash_values_iter = hash_values |
935 | 12.4k | .iter() |
936 | 12.4k | .enumerate() |
937 | 31.3k | .map(|(i, val)| (i + offset, val)); |
938 | 12.4k | |
939 | 12.4k | if fifo_hashmap { |
940 | 4.30k | hash_map.update_from_iter(hash_values_iter.rev(), deleted_offset); |
941 | 8.11k | } else { |
942 | 8.11k | hash_map.update_from_iter(hash_values_iter, deleted_offset); |
943 | 8.11k | } |
944 | | |
945 | 12.4k | Ok(()) |
946 | 12.4k | } |
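
The effect of `fifo_hashmap` can be seen with a simplified chain model: new entries go to the *front* of a per-hash chain, so the direction of iteration decides which row index becomes the chain head. A toy sketch, with a plain `HashMap` of `Vec` chains standing in for `JoinHashMapType`:

```rust
use std::collections::HashMap;

/// Toy stand-in for the join hash table: each hash maps to a chain of row
/// indices, with new entries pushed onto the *front* of the chain.
fn update_from_iter<'a>(
    map: &mut HashMap<u64, Vec<usize>>,
    rows: impl Iterator<Item = (usize, &'a u64)>,
) {
    for (row_idx, hash) in rows {
        map.entry(*hash).or_default().insert(0, row_idx);
    }
}

fn main() {
    let hashes = vec![42u64, 42, 42]; // three rows with equal hash values
    let iter = || hashes.iter().enumerate();

    // LIFO (fifo_hashmap = false): forward insertion leaves the *last*
    // row index at the chain head.
    let mut lifo = HashMap::new();
    update_from_iter(&mut lifo, iter());
    assert_eq!(lifo[&42], vec![2, 1, 0]);

    // FIFO (fifo_hashmap = true): reversed insertion leaves the *first*
    // row index at the chain head, preserving build-side input order.
    let mut fifo = HashMap::new();
    update_from_iter(&mut fifo, iter().rev());
    assert_eq!(fifo[&42], vec![0, 1, 2]);
}
```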
947 | | |
948 | | /// Represents the build side of a hash join. |
949 | | enum BuildSide { |
950 | | /// Indicates that the build side has not been collected yet |
951 | | Initial(BuildSideInitialState), |
952 | | /// Indicates that build-side data has been collected |
953 | | Ready(BuildSideReadyState), |
954 | | } |
955 | | |
956 | | /// Container for BuildSide::Initial related data |
957 | | struct BuildSideInitialState { |
958 | | /// Future for building hash table from build-side input |
959 | | left_fut: OnceFut<JoinLeftData>, |
960 | | } |
961 | | |
962 | | /// Container for BuildSide::Ready related data |
963 | | struct BuildSideReadyState { |
964 | | /// Collected build-side data |
965 | | left_data: Arc<JoinLeftData>, |
966 | | } |
967 | | |
968 | | impl BuildSide { |
969 | | /// Tries to extract BuildSideInitialState from BuildSide enum. |
970 | | /// Returns an error if state is not Initial. |
971 | 3.25k | fn try_as_initial_mut(&mut self) -> Result<&mut BuildSideInitialState> { |
972 | 3.25k | match self { |
973 | 3.25k | BuildSide::Initial(state) => Ok(state), |
974 | 0 | _ => internal_err!("Expected build side in initial state"), |
975 | | } |
976 | 3.25k | } |
977 | | |
978 | | /// Tries to extract BuildSideReadyState from BuildSide enum. |
979 | | /// Returns an error if state is not Ready. |
980 | 862 | fn try_as_ready(&self) -> Result<&BuildSideReadyState> { |
981 | 862 | match self { |
982 | 862 | BuildSide::Ready(state) => Ok(state), |
983 | 0 | _ => internal_err!("Expected build side in ready state"), |
984 | | } |
985 | 862 | } |
986 | | |
987 | | /// Tries to extract BuildSideReadyState from BuildSide enum. |
988 | | /// Returns an error if state is not Ready. |
989 | 5.16k | fn try_as_ready_mut(&mut self) -> Result<&mut BuildSideReadyState> { |
990 | 5.16k | match self { |
991 | 5.16k | BuildSide::Ready(state) => Ok(state), |
992 | 0 | _ => internal_err!("Expected build side in ready state"), |
993 | | } |
994 | 5.16k | } |
995 | | } |
996 | | |
997 | | /// Represents state of HashJoinStream |
998 | | /// |
999 | | /// Expected state transitions performed by HashJoinStream are: |
1000 | | /// |
1001 | | /// ```text |
1002 | | /// |
1003 | | /// WaitBuildSide |
1004 | | /// │ |
1005 | | /// ▼ |
1006 | | /// ┌─► FetchProbeBatch ───► ExhaustedProbeSide ───► Completed |
1007 | | /// │ │ |
1008 | | /// │ ▼ |
1009 | | /// └─ ProcessProbeBatch |
1010 | | /// |
1011 | | /// ``` |
1012 | | enum HashJoinStreamState { |
1013 | | /// Initial state for HashJoinStream indicating that build-side data has not been collected yet |
1014 | | WaitBuildSide, |
1015 | | /// Indicates that the build side has been collected, and the stream is ready to fetch the probe side |
1016 | | FetchProbeBatch, |
1017 | | /// Indicates that a non-empty batch has been fetched from the probe side and is ready to be processed |
1018 | | ProcessProbeBatch(ProcessProbeBatchState), |
1019 | | /// Indicates that probe-side has been fully processed |
1020 | | ExhaustedProbeSide, |
1021 | | /// Indicates that HashJoinStream execution is completed |
1022 | | Completed, |
1023 | | } |
1024 | | |
1025 | | impl HashJoinStreamState { |
1026 | | /// Tries to extract ProcessProbeBatchState from HashJoinStreamState enum. |
1027 | | /// Returns an error if state is not ProcessProbeBatchState. |
1028 | 5.16k | fn try_as_process_probe_batch_mut(&mut self) -> Result<&mut ProcessProbeBatchState> { |
1029 | 5.16k | match self { |
1030 | 5.16k | HashJoinStreamState::ProcessProbeBatch(state) => Ok(state), |
1031 | 0 | _ => internal_err!("Expected hash join stream in ProcessProbeBatch state"), |
1032 | | } |
1033 | 5.16k | } |
1034 | | } |
1035 | | |
1036 | | /// Container for HashJoinStreamState::ProcessProbeBatch related data |
1037 | | struct ProcessProbeBatchState { |
1038 | | /// Current probe-side batch |
1039 | | batch: RecordBatch, |
1040 | | /// Starting offset for JoinHashMap lookups |
1041 | | offset: JoinHashMapOffset, |
1042 | | /// Max joined probe-side index from current batch |
1043 | | joined_probe_idx: Option<usize>, |
1044 | | } |
1045 | | |
1046 | | impl ProcessProbeBatchState { |
1047 | 615 | fn advance(&mut self, offset: JoinHashMapOffset, joined_probe_idx: Option<usize>) { |
1048 | 615 | self.offset = offset; |
1049 | 615 | if joined_probe_idx.is_some() { |
1050 | 586 | self.joined_probe_idx = joined_probe_idx; |
1051 | 586 | }29 |
1052 | 615 | } |
1053 | | } |
1054 | | |
1055 | | /// [`Stream`] for [`HashJoinExec`] that does the actual join. |
1056 | | /// |
1057 | | /// This stream: |
1058 | | /// |
1059 | | /// 1. Reads the entire left input (build) and constructs a hash table |
1060 | | /// |
1061 | | /// 2. Streams [RecordBatch]es as they arrive from the right input (probe) and joins |
1062 | | /// them with the contents of the hash table |
1063 | | struct HashJoinStream { |
1064 | | /// Input schema |
1065 | | schema: Arc<Schema>, |
1066 | | /// equijoin columns from the left (build side) |
1067 | | on_left: Vec<PhysicalExprRef>, |
1068 | | /// equijoin columns from the right (probe side) |
1069 | | on_right: Vec<PhysicalExprRef>, |
1070 | | /// optional join filter |
1071 | | filter: Option<JoinFilter>, |
1072 | | /// type of the join (left, right, semi, etc) |
1073 | | join_type: JoinType, |
1074 | | /// right (probe) input |
1075 | | right: SendableRecordBatchStream, |
1076 | | /// Random state used for hashing initialization |
1077 | | random_state: RandomState, |
1078 | | /// Metrics |
1079 | | join_metrics: BuildProbeJoinMetrics, |
1080 | | /// Information of index and left / right placement of columns |
1081 | | column_indices: Vec<ColumnIndex>, |
1082 | | /// If null_equals_null is true, null == null else null != null |
1083 | | null_equals_null: bool, |
1084 | | /// State of the stream |
1085 | | state: HashJoinStreamState, |
1086 | | /// Build side |
1087 | | build_side: BuildSide, |
1088 | | /// Maximum output batch size |
1089 | | batch_size: usize, |
1090 | | /// Scratch space for computing hashes |
1091 | | hashes_buffer: Vec<u64>, |
1092 | | /// Specifies whether the right side has an ordering to potentially preserve |
1093 | | right_side_ordered: bool, |
1094 | | } |
1095 | | |
1096 | | impl RecordBatchStream for HashJoinStream { |
1097 | 0 | fn schema(&self) -> SchemaRef { |
1098 | 0 | Arc::clone(&self.schema) |
1099 | 0 | } |
1100 | | } |
1101 | | |
1102 | | /// Executes lookups by hash against JoinHashMap and resolves potential |
1103 | | /// hash collisions. |
1104 | | /// Returns build/probe indices satisfying the equality condition, along with an |
1105 | | /// (optional) starting point for the next iteration. |
1106 | | /// |
1107 | | /// # Example |
1108 | | /// |
1109 | | /// For `LEFT.b1 = RIGHT.b2`: |
1110 | | /// LEFT (build) Table: |
1111 | | /// ```text |
1112 | | /// a1 b1 c1 |
1113 | | /// 1 1 10 |
1114 | | /// 3 3 30 |
1115 | | /// 5 5 50 |
1116 | | /// 7 7 70 |
1117 | | /// 9 8 90 |
1118 | | /// 11 8 110 |
1119 | | /// 13 10 130 |
1120 | | /// ``` |
1121 | | /// |
1122 | | /// RIGHT (probe) Table: |
1123 | | /// ```text |
1124 | | /// a2 b2 c2 |
1125 | | /// 2 2 20 |
1126 | | /// 4 4 40 |
1127 | | /// 6 6 60 |
1128 | | /// 8 8 80 |
1129 | | /// 10 10 100 |
1130 | | /// 12 10 120 |
1131 | | /// ``` |
1132 | | /// |
1133 | | /// The result is |
1134 | | /// ```text |
1135 | | /// "+----+----+-----+----+----+-----+", |
1136 | | /// "| a1 | b1 | c1 | a2 | b2 | c2 |", |
1137 | | /// "+----+----+-----+----+----+-----+", |
1138 | | /// "| 9 | 8 | 90 | 8 | 8 | 80 |", |
1139 | | /// "| 11 | 8 | 110 | 8 | 8 | 80 |", |
1140 | | /// "| 13 | 10 | 130 | 10 | 10 | 100 |", |
1141 | | /// "| 13 | 10 | 130 | 12 | 10 | 120 |", |
1142 | | /// "+----+----+-----+----+----+-----+" |
1143 | | /// ``` |
1144 | | /// |
1145 | | /// And the result of build and probe indices are: |
1146 | | /// ```text |
1147 | | /// Build indices: 4, 5, 6, 6 |
1148 | | /// Probe indices: 3, 3, 4, 5 |
1149 | | /// ``` |
1150 | | #[allow(clippy::too_many_arguments)] |
1151 | 5.16k | fn lookup_join_hashmap( |
1152 | 5.16k | build_hashmap: &JoinHashMap, |
1153 | 5.16k | build_input_buffer: &RecordBatch, |
1154 | 5.16k | probe_batch: &RecordBatch, |
1155 | 5.16k | build_on: &[PhysicalExprRef], |
1156 | 5.16k | probe_on: &[PhysicalExprRef], |
1157 | 5.16k | null_equals_null: bool, |
1158 | 5.16k | hashes_buffer: &[u64], |
1159 | 5.16k | limit: usize, |
1160 | 5.16k | offset: JoinHashMapOffset, |
1161 | 5.16k | ) -> Result<(UInt64Array, UInt32Array, Option<JoinHashMapOffset>)> { |
1162 | 5.16k | let keys_values = probe_on |
1163 | 5.16k | .iter() |
1164 | 5.18k | .map(|c| c.evaluate(probe_batch)?0 .into_array(probe_batch.num_rows())) |
1165 | 5.16k | .collect::<Result<Vec<_>>>()?0 ; |
1166 | 5.16k | let build_join_values = build_on |
1167 | 5.16k | .iter() |
1168 | 5.18k | .map(|c| { |
1169 | 5.18k | c.evaluate(build_input_buffer)?0 |
1170 | 5.18k | .into_array(build_input_buffer.num_rows()) |
1171 | 5.18k | }) |
1172 | 5.16k | .collect::<Result<Vec<_>>>()?0 ; |
1173 | | |
1174 | 5.16k | let (probe_indices, build_indices, next_offset) = build_hashmap |
1175 | 5.16k | .get_matched_indices_with_limit_offset(hashes_buffer, None, limit, offset); |
1176 | 5.16k | |
1177 | 5.16k | let build_indices: UInt64Array = build_indices.into(); |
1178 | 5.16k | let probe_indices: UInt32Array = probe_indices.into(); |
1179 | | |
1180 | 5.16k | let (build_indices, probe_indices) = equal_rows_arr( |
1181 | 5.16k | &build_indices, |
1182 | 5.16k | &probe_indices, |
1183 | 5.16k | &build_join_values, |
1184 | 5.16k | &keys_values, |
1185 | 5.16k | null_equals_null, |
1186 | 5.16k | )?0 ; |
1187 | | |
1188 | 5.16k | Ok((build_indices, probe_indices, next_offset)) |
1189 | 5.16k | } |
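
Because each lookup is capped at `limit` (the session `batch_size`), a single probe batch may need several passes: the returned `next_offset` is fed back in (via `ProcessProbeBatchState::advance`) until it comes back as `None`. Schematically, with toy types in place of the real `JoinHashMapOffset`:

```rust
/// Toy resumable lookup: yields at most `limit` matches per call and
/// returns where to resume, mirroring `lookup_join_hashmap`'s contract.
fn lookup_chunk(matches: &[u32], limit: usize, offset: usize) -> (Vec<u32>, Option<usize>) {
    let end = (offset + limit).min(matches.len());
    let next = (end < matches.len()).then_some(end);
    (matches[offset..end].to_vec(), next)
}

fn main() {
    let all_matches: Vec<u32> = (0..10).collect();
    let mut offset = 0;
    let mut emitted = Vec::new();
    loop {
        // Each iteration corresponds to one output batch of at most
        // `batch_size` rows from process_probe_batch.
        let (chunk, next) = lookup_chunk(&all_matches, 4, offset);
        emitted.extend(chunk);
        match next {
            Some(n) => offset = n, // state.advance(next_offset, ...)
            None => break,         // probe batch fully processed
        }
    }
    assert_eq!(emitted, all_matches);
}
```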
1190 | | |
1191 | | // version of eq_dyn supporting equality on null arrays |
1192 | 12.0k | fn eq_dyn_null( |
1193 | 12.0k | left: &dyn Array, |
1194 | 12.0k | right: &dyn Array, |
1195 | 12.0k | null_equals_null: bool, |
1196 | 12.0k | ) -> Result<BooleanArray, ArrowError> { |
1197 | 12.0k | // Nested datatypes cannot use the underlying not_distinct/eq function and must use a special |
1198 | 12.0k | // implementation |
1199 | 12.0k | // <https://github.com/apache/datafusion/issues/10749> |
1200 | 12.0k | if left.data_type().is_nested() { |
1201 | 3 | let op = if null_equals_null { |
1202 | 1 | Operator::IsNotDistinctFrom |
1203 | | } else { |
1204 | 2 | Operator::Eq |
1205 | | }; |
1206 | 3 | return Ok(compare_op_for_nested(op, &left, &right)?0 ); |
1207 | 12.0k | } |
1208 | 12.0k | match (left.data_type(), right.data_type()) { |
1209 | 12.0k | _ if null_equals_null => not_distinct(&left, &right)0 , |
1210 | 12.0k | _ => eq(&left, &right), |
1211 | | } |
1212 | 12.0k | } |
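
The two kernels differ only in null handling: `eq` yields null whenever either side is null, while `not_distinct` treats two nulls as equal. A quick demonstration with arrow arrays:

```rust
use arrow::array::{BooleanArray, Int32Array};
use arrow::compute::kernels::cmp::{eq, not_distinct};
use arrow::error::ArrowError;

fn main() -> Result<(), ArrowError> {
    let left = Int32Array::from(vec![Some(1), None, Some(3)]);
    let right = Int32Array::from(vec![Some(1), None, None]);

    // SQL `=` semantics: any null operand produces a null result.
    assert_eq!(
        eq(&left, &right)?,
        BooleanArray::from(vec![Some(true), None, None])
    );

    // `IS NOT DISTINCT FROM` semantics (null_equals_null = true):
    // null == null is true, null vs non-null is false.
    assert_eq!(
        not_distinct(&left, &right)?,
        BooleanArray::from(vec![true, true, false])
    );
    Ok(())
}
```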
1213 | | |
1214 | 12.0k | pub fn equal_rows_arr( |
1215 | 12.0k | indices_left: &UInt64Array, |
1216 | 12.0k | indices_right: &UInt32Array, |
1217 | 12.0k | left_arrays: &[ArrayRef], |
1218 | 12.0k | right_arrays: &[ArrayRef], |
1219 | 12.0k | null_equals_null: bool, |
1220 | 12.0k | ) -> Result<(UInt64Array, UInt32Array)> { |
1221 | 12.0k | let mut iter = left_arrays.iter().zip(right_arrays.iter()); |
1222 | | |
1223 | 12.0k | let (first_left, first_right) = iter.next().ok_or_else(|| { |
1224 | 0 | DataFusionError::Internal( |
1225 | 0 | "At least one array should be provided for both left and right".to_string(), |
1226 | 0 | ) |
1227 | 12.0k | })?0 ; |
1228 | | |
1229 | 12.0k | let arr_left = take(first_left.as_ref(), indices_left, None)?0 ; |
1230 | 12.0k | let arr_right = take(first_right.as_ref(), indices_right, None)?0 ; |
1231 | | |
1232 | 12.0k | let mut equal: BooleanArray = eq_dyn_null(&arr_left, &arr_right, null_equals_null)?0 ; |
1233 | | |
1234 | | // Use map and try_fold to iterate over the remaining pairs of arrays. |
1235 | | // In each iteration, take is used on the pair of arrays and their equality is determined. |
1236 | | // The results are then folded (combined) using the and function to get a final equality result. |
1237 | 12.0k | equal = iter |
1238 | 12.0k | .map(|(left, right)| {18 |
1239 | 18 | let arr_left = take(left.as_ref(), indices_left, None)?0 ; |
1240 | 18 | let arr_right = take(right.as_ref(), indices_right, None)?0 ; |
1241 | 18 | eq_dyn_null(arr_left.as_ref(), arr_right.as_ref(), null_equals_null) |
1242 | 12.0k | }18 ) |
1243 | 12.0k | .try_fold(equal, |acc, equal2| and(&acc, &equal218 ?0 )18 )?0 ; |
1244 | | |
1245 | 12.0k | let filter_builder = FilterBuilder::new(&equal).optimize().build(); |
1246 | | |
1247 | 12.0k | let left_filtered = filter_builder.filter(indices_left)?0 ; |
1248 | 12.0k | let right_filtered = filter_builder.filter(indices_right)?0 ; |
1249 | | |
1250 | 12.0k | Ok(( |
1251 | 12.0k | downcast_array(left_filtered.as_ref()), |
1252 | 12.0k | downcast_array(right_filtered.as_ref()), |
1253 | 12.0k | )) |
1254 | 12.0k | } |
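
`equal_rows_arr` is the hash-collision check: candidate index pairs from the hash table may pair rows whose keys merely hash the same, and the `take` + compare + filter sequence keeps only truly equal pairs. A hedged usage sketch (assuming `equal_rows_arr` is in scope, e.g. from within this module; the data is made up):

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array, UInt32Array, UInt64Array};
use datafusion_common::Result;

// Uses `equal_rows_arr` as defined above in this module.
fn demo() -> Result<()> {
    let build: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![10, 20, 30]))];
    let probe: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![20, 99]))];

    // Candidate pairs from the hash table: (build 1, probe 0) is a real
    // match; (build 2, probe 1) is a hypothetical hash collision.
    let build_indices = UInt64Array::from(vec![1, 2]);
    let probe_indices = UInt32Array::from(vec![0, 1]);

    let (build_kept, probe_kept) =
        equal_rows_arr(&build_indices, &probe_indices, &build, &probe, false)?;

    // Only the genuinely equal pair survives.
    assert_eq!(build_kept, UInt64Array::from(vec![1]));
    assert_eq!(probe_kept, UInt32Array::from(vec![0]));
    Ok(())
}
```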
1255 | | |
1256 | 850 | fn get_final_indices_from_shared_bitmap( |
1257 | 850 | shared_bitmap: &SharedBitmapBuilder, |
1258 | 850 | join_type: JoinType, |
1259 | 850 | ) -> (UInt64Array, UInt32Array) { |
1260 | 850 | let bitmap = shared_bitmap.lock(); |
1261 | 850 | get_final_indices_from_bit_map(&bitmap, join_type) |
1262 | 850 | } |
1263 | | |
1264 | | impl HashJoinStream { |
1265 | | /// Separate implementation function that unpins the [`HashJoinStream`] so |
1266 | | /// that partial borrows work correctly |
1267 | 10.1k | fn poll_next_impl( |
1268 | 10.1k | &mut self, |
1269 | 10.1k | cx: &mut std::task::Context<'_>, |
1270 | 10.1k | ) -> Poll<Option<Result<RecordBatch>>> { |
1271 | | loop { |
1272 | 19.0k | return match self.state { |
1273 | | HashJoinStreamState::WaitBuildSide => { |
1274 | 3.25k | handle_state!16 (ready!1.49k (self.collect_build_side(cx))) |
1275 | | } |
1276 | | HashJoinStreamState::FetchProbeBatch => { |
1277 | 7.14k | handle_state!8 (ready!858 (self.fetch_probe_batch(cx))) |
1278 | | } |
1279 | | HashJoinStreamState::ProcessProbeBatch(_) => { |
1280 | 5.16k | handle_state!0 (self.process_probe_batch()) |
1281 | | } |
1282 | | HashJoinStreamState::ExhaustedProbeSide => { |
1283 | 1.73k | handle_state!0 (self.process_unmatched_build_batch()) |
1284 | | } |
1285 | 1.73k | HashJoinStreamState::Completed => Poll::Ready(None), |
1286 | | }; |
1287 | | } |
1288 | 10.1k | } |
1289 | | |
1290 | | /// Collects build-side data by polling the `OnceFut` future from the initialized build side |
1291 | | /// |
1292 | | /// Updates build-side to `Ready`, and state to `FetchProbeBatch` |
1293 | 3.25k | fn collect_build_side( |
1294 | 3.25k | &mut self, |
1295 | 3.25k | cx: &mut std::task::Context<'_>, |
1296 | 3.25k | ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> { |
1297 | 3.25k | let build_timer = self.join_metrics.build_time.timer(); |
1298 | | // build hash table from left (build) side, if not yet done |
1299 | 3.25k | let left_data1.74k = ready!1.49k (self |
1300 | 3.25k | .build_side |
1301 | 3.25k | .try_as_initial_mut()?0 |
1302 | | .left_fut |
1303 | 3.25k | .get_shared(cx))?16 ; |
1304 | 1.74k | build_timer.done(); |
1305 | 1.74k | |
1306 | 1.74k | self.state = HashJoinStreamState::FetchProbeBatch; |
1307 | 1.74k | self.build_side = BuildSide::Ready(BuildSideReadyState { left_data }); |
1308 | 1.74k | |
1309 | 1.74k | Poll::Ready(Ok(StatefulStreamResult::Continue)) |
1310 | 3.25k | } |
1311 | | |
1312 | | /// Fetches the next batch from the probe side |
1313 | | /// |
1314 | | /// If a non-empty batch has been fetched, updates state to `ProcessProbeBatch`, |
1315 | | /// otherwise updates state to `ExhaustedProbeSide` |
1316 | 7.14k | fn fetch_probe_batch( |
1317 | 7.14k | &mut self, |
1318 | 7.14k | cx: &mut std::task::Context<'_>, |
1319 | 7.14k | ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> { |
1320 | 7.14k | match ready!858 (self.right.poll_next_unpin(cx)) { |
1321 | 1.73k | None => { |
1322 | 1.73k | self.state = HashJoinStreamState::ExhaustedProbeSide; |
1323 | 1.73k | } |
1324 | 4.54k | Some(Ok(batch)) => { |
1325 | | // Precalculate hash values for fetched batch |
1326 | 4.54k | let keys_values = self |
1327 | 4.54k | .on_right |
1328 | 4.54k | .iter() |
1329 | 4.55k | .map(|c| c.evaluate(&batch)?0 .into_array(batch.num_rows())) |
1330 | 4.54k | .collect::<Result<Vec<_>>>()?0 ; |
1331 | | |
1332 | 4.54k | self.hashes_buffer.clear(); |
1333 | 4.54k | self.hashes_buffer.resize(batch.num_rows(), 0); |
1334 | 4.54k | create_hashes(&keys_values, &self.random_state, &mut self.hashes_buffer)?0 ; |
1335 | | |
1336 | 4.54k | self.join_metrics.input_batches.add(1); |
1337 | 4.54k | self.join_metrics.input_rows.add(batch.num_rows()); |
1338 | 4.54k | |
1339 | 4.54k | self.state = |
1340 | 4.54k | HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState { |
1341 | 4.54k | batch, |
1342 | 4.54k | offset: (0, None), |
1343 | 4.54k | joined_probe_idx: None, |
1344 | 4.54k | }); |
1345 | | } |
1346 | 8 | Some(Err(err)) => return Poll::Ready(Err(err)), |
1347 | | }; |
1348 | | |
1349 | 6.28k | Poll::Ready(Ok(StatefulStreamResult::Continue)) |
1350 | 7.14k | } |
1351 | | |
1352 | | /// Joins the current probe batch with build-side data and produces a batch of matched output |
1353 | | /// |
1354 | | /// Updates state to `FetchProbeBatch` |
1355 | 5.16k | fn process_probe_batch( |
1356 | 5.16k | &mut self, |
1357 | 5.16k | ) -> Result<StatefulStreamResult<Option<RecordBatch>>> { |
1358 | 5.16k | let state = self.state.try_as_process_probe_batch_mut()?0 ; |
1359 | 5.16k | let build_side = self.build_side.try_as_ready_mut()?0 ; |
1360 | | |
1361 | 5.16k | let timer = self.join_metrics.join_time.timer(); |
1362 | | |
1363 | | // get the matched by join keys indices |
1364 | 5.16k | let (left_indices, right_indices, next_offset) = lookup_join_hashmap( |
1365 | 5.16k | build_side.left_data.hash_map(), |
1366 | 5.16k | build_side.left_data.batch(), |
1367 | 5.16k | &state.batch, |
1368 | 5.16k | &self.on_left, |
1369 | 5.16k | &self.on_right, |
1370 | 5.16k | self.null_equals_null, |
1371 | 5.16k | &self.hashes_buffer, |
1372 | 5.16k | self.batch_size, |
1373 | 5.16k | state.offset, |
1374 | 5.16k | )?;
1375 | | |
1376 | | // apply join filter if exists |
1377 | 5.16k | let (left_indices, right_indices) = if let Some(filter) = &self.filter {
1378 | 4.20k | apply_join_filter_to_indices( |
1379 | 4.20k | build_side.left_data.batch(), |
1380 | 4.20k | &state.batch, |
1381 | 4.20k | left_indices, |
1382 | 4.20k | right_indices, |
1383 | 4.20k | filter, |
1384 | 4.20k | JoinSide::Left, |
1385 | 4.20k | )?
1386 | | } else { |
1387 | 959 | (left_indices, right_indices) |
1388 | | }; |
1389 | | |
1390 | | // mark joined left-side indices as visited, if required by join type |
1391 | 5.16k | if need_produce_result_in_final(self.join_type) { |
1392 | 2.60k | let mut bitmap = build_side.left_data.visited_indices_bitmap().lock(); |
1393 | 5.13k | left_indices.iter().flatten().for_each(|x| { |
1394 | 5.13k | bitmap.set_bit(x as usize, true); |
1395 | 5.13k | }); |
1396 | 2.60k | }
1397 | | |
1398 | | // The goals of index alignment for the different join types are:
1399 | | //
1400 | | // 1) Right & Full joins -- append all missing probe-side indices between
1401 | | // the previous (exclusive) and current joined indices.
1402 | | // 2) Semi joins -- deduplicate probe indices in the range between the
1403 | | // previous (exclusive) and current joined indices.
1404 | | // 3) Anti joins -- return only the missing indices in the range between
1405 | | // the previous and current joined indices.
1406 | | // Whether the range endpoints themselves are included doesn't matter.
1407 | | //
1408 | | // In summary: the alignment range can be produced based only on the
1409 | | // joined (matched, with filters applied) probe-side indices, excluding the
1410 | | // starting one (left over from the previous iteration).
1411 | | |
1412 | | // If any rows have been joined, get the last joined probe-side (right) row:
1413 | | // an index only counts as "joined" after hash-collision checks and join
1414 | | // filters have been applied.
1415 | 5.16k | let last_joined_right_idx = match right_indices.len() { |
1416 | 2.95k | 0 => None, |
1417 | 2.20k | n => Some(right_indices.value(n - 1) as usize), |
1418 | | }; |
1419 | | |
1420 | | // Calculate the range and perform the alignment.
1421 | | // If the probe batch has been fully processed, align all remaining rows.
1422 | 5.16k | let index_alignment_range_start = state.joined_probe_idx.map_or(0, |v| v + 1);
1423 | 5.16k | let index_alignment_range_end = if next_offset.is_none() { |
1424 | 4.54k | state.batch.num_rows() |
1425 | | } else { |
1426 | 615 | last_joined_right_idx.map_or(0, |v| v + 1)
1427 | | }; |
1428 | | |
1429 | 5.16k | let (left_indices, right_indices) = adjust_indices_by_join_type( |
1430 | 5.16k | left_indices, |
1431 | 5.16k | right_indices, |
1432 | 5.16k | index_alignment_range_start..index_alignment_range_end, |
1433 | 5.16k | self.join_type, |
1434 | 5.16k | self.right_side_ordered, |
1435 | 5.16k | ); |
1436 | | |
1437 | 5.16k | let result = build_batch_from_indices( |
1438 | 5.16k | &self.schema, |
1439 | 5.16k | build_side.left_data.batch(), |
1440 | 5.16k | &state.batch, |
1441 | 5.16k | &left_indices, |
1442 | 5.16k | &right_indices, |
1443 | 5.16k | &self.column_indices, |
1444 | 5.16k | JoinSide::Left, |
1445 | 5.16k | )?;
1446 | | |
1447 | 5.16k | self.join_metrics.output_batches.add(1); |
1448 | 5.16k | self.join_metrics.output_rows.add(result.num_rows()); |
1449 | 5.16k | timer.done(); |
1450 | 5.16k | |
1451 | 5.16k | if next_offset.is_none() { |
1452 | 4.54k | self.state = HashJoinStreamState::FetchProbeBatch; |
1453 | 4.54k | } else { |
1454 | 615 | state.advance( |
1455 | 615 | next_offset |
1456 | 615 | .ok_or_else(|| internal_datafusion_err!("unexpected None offset"))?,
1457 | 615 | last_joined_right_idx, |
1458 | | ) |
1459 | | }; |
1460 | | |
1461 | 5.16k | Ok(StatefulStreamResult::Ready(Some(result))) |
1462 | 5.16k | } |
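 | | // [editor's note] `next_offset` implements output limiting: when a probe
 | | // batch would produce more than `batch_size` rows, `lookup_join_hashmap`
 | | // stops early and returns Some(offset); the state then stays in
 | | // ProcessProbeBatch and the next poll resumes the same probe batch from
 | | // that offset via `state.advance(..)` above.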
1463 | | |
1464 | | /// Processes unmatched build-side rows for certain join types and produces an output batch
1465 | | /// |
1466 | | /// Updates state to `Completed` |
1467 | 1.73k | fn process_unmatched_build_batch( |
1468 | 1.73k | &mut self, |
1469 | 1.73k | ) -> Result<StatefulStreamResult<Option<RecordBatch>>> { |
1470 | 1.73k | let timer = self.join_metrics.join_time.timer(); |
1471 | 1.73k | |
1472 | 1.73k | if !need_produce_result_in_final(self.join_type) { |
1473 | 870 | self.state = HashJoinStreamState::Completed; |
1474 | 870 | return Ok(StatefulStreamResult::Continue); |
1475 | 862 | } |
1476 | | |
1477 | 862 | let build_side = self.build_side.try_as_ready()?;
1478 | 862 | if !build_side.left_data.report_probe_completed() { |
1479 | 12 | self.state = HashJoinStreamState::Completed; |
1480 | 12 | return Ok(StatefulStreamResult::Continue); |
1481 | 850 | } |
1482 | 850 | |
1483 | 850 | // use the global left bitmap to produce the left indices and right indices |
1484 | 850 | let (left_side, right_side) = get_final_indices_from_shared_bitmap( |
1485 | 850 | build_side.left_data.visited_indices_bitmap(), |
1486 | 850 | self.join_type, |
1487 | 850 | ); |
1488 | 850 | let empty_right_batch = RecordBatch::new_empty(self.right.schema()); |
1489 | 850 | // use the left and right indices to produce the batch result |
1490 | 850 | let result = build_batch_from_indices( |
1491 | 850 | &self.schema, |
1492 | 850 | build_side.left_data.batch(), |
1493 | 850 | &empty_right_batch, |
1494 | 850 | &left_side, |
1495 | 850 | &right_side, |
1496 | 850 | &self.column_indices, |
1497 | 850 | JoinSide::Left, |
1498 | 850 | ); |
1499 | | |
1500 | 850 | if let Ok(ref batch) = result { |
1501 | 850 | self.join_metrics.input_batches.add(1); |
1502 | 850 | self.join_metrics.input_rows.add(batch.num_rows()); |
1503 | 850 | |
1504 | 850 | self.join_metrics.output_batches.add(1); |
1505 | 850 | self.join_metrics.output_rows.add(batch.num_rows()); |
1506 | 850 | }
1507 | 850 | timer.done(); |
1508 | 850 | |
1509 | 850 | self.state = HashJoinStreamState::Completed; |
1510 | 850 | |
1511 | 850 | Ok(StatefulStreamResult::Ready(Some(result?)))
1512 | 1.73k | } |
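 | | // [editor's note] `report_probe_completed` coordinates streams that share
 | | // one build side (e.g. CollectLeft with several output partitions): it
 | | // appears to return true only for the last stream to finish probing, so
 | | // unmatched build-side rows are emitted exactly once -- hence 862 streams
 | | // reach this point above but only 850 produce output.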
1513 | | } |
1514 | | |
1515 | | impl Stream for HashJoinStream { |
1516 | | type Item = Result<RecordBatch>; |
1517 | | |
1518 | 10.1k | fn poll_next( |
1519 | 10.1k | mut self: std::pin::Pin<&mut Self>, |
1520 | 10.1k | cx: &mut std::task::Context<'_>, |
1521 | 10.1k | ) -> std::task::Poll<Option<Self::Item>> { |
1522 | 10.1k | self.poll_next_impl(cx) |
1523 | 10.1k | } |
1524 | | } |
1525 | | |
1526 | | #[cfg(test)] |
1527 | | mod tests { |
1528 | | use super::*; |
1529 | | use crate::{ |
1530 | | common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec, |
1531 | | test::build_table_i32, test::exec::MockExec, |
1532 | | }; |
1533 | | |
1534 | | use arrow::array::{Date32Array, Int32Array}; |
1535 | | use arrow::datatypes::{DataType, Field}; |
1536 | | use arrow_array::StructArray; |
1537 | | use arrow_buffer::NullBuffer; |
1538 | | use datafusion_common::{ |
1539 | | assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, |
1540 | | ScalarValue, |
1541 | | }; |
1542 | | use datafusion_execution::config::SessionConfig; |
1543 | | use datafusion_execution::runtime_env::RuntimeEnvBuilder; |
1544 | | use datafusion_expr::Operator; |
1545 | | use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; |
1546 | | use datafusion_physical_expr::PhysicalExpr; |
1547 | | |
1548 | | use hashbrown::raw::RawTable; |
1549 | | use rstest::*; |
1550 | | use rstest_reuse::*; |
1551 | | |
1552 | 180 | fn div_ceil(a: usize, b: usize) -> usize { |
1553 | 180 | (a + b - 1) / b |
1554 | 180 | } |
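 | | // [editor's note] Ceiling division, e.g. div_ceil(3, 2) == 2 and
 | | // div_ceil(9, 4) == 3. The standard library's `usize::div_ceil` (stable
 | | // since Rust 1.73) computes the same thing without `a + b - 1` being able
 | | // to overflow when `a` is near `usize::MAX`.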
1555 | | |
1556 | | #[template] |
1557 | 130 | #[rstest] |
1558 | | fn batch_sizes(#[values(8192, 10, 5, 2, 1)] batch_size: usize) {} |
1559 | | |
1560 | 290 | fn prepare_task_ctx(batch_size: usize) -> Arc<TaskContext> { |
1561 | 290 | let session_config = SessionConfig::default().with_batch_size(batch_size); |
1562 | 290 | Arc::new(TaskContext::default().with_session_config(session_config)) |
1563 | 290 | } |
1564 | | |
1565 | 244 | fn build_table( |
1566 | 244 | a: (&str, &Vec<i32>), |
1567 | 244 | b: (&str, &Vec<i32>), |
1568 | 244 | c: (&str, &Vec<i32>), |
1569 | 244 | ) -> Arc<dyn ExecutionPlan> { |
1570 | 244 | let batch = build_table_i32(a, b, c); |
1571 | 244 | let schema = batch.schema(); |
1572 | 244 | Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) |
1573 | 244 | } |
1574 | | |
1575 | 259 | fn join( |
1576 | 259 | left: Arc<dyn ExecutionPlan>, |
1577 | 259 | right: Arc<dyn ExecutionPlan>, |
1578 | 259 | on: JoinOn, |
1579 | 259 | join_type: &JoinType, |
1580 | 259 | null_equals_null: bool, |
1581 | 259 | ) -> Result<HashJoinExec> { |
1582 | 259 | HashJoinExec::try_new( |
1583 | 259 | left, |
1584 | 259 | right, |
1585 | 259 | on, |
1586 | 259 | None, |
1587 | 259 | join_type, |
1588 | 259 | None, |
1589 | 259 | PartitionMode::CollectLeft, |
1590 | 259 | null_equals_null, |
1591 | 259 | ) |
1592 | 259 | } |
1593 | | |
1594 | 60 | fn join_with_filter( |
1595 | 60 | left: Arc<dyn ExecutionPlan>, |
1596 | 60 | right: Arc<dyn ExecutionPlan>, |
1597 | 60 | on: JoinOn, |
1598 | 60 | filter: JoinFilter, |
1599 | 60 | join_type: &JoinType, |
1600 | 60 | null_equals_null: bool, |
1601 | 60 | ) -> Result<HashJoinExec> { |
1602 | 60 | HashJoinExec::try_new( |
1603 | 60 | left, |
1604 | 60 | right, |
1605 | 60 | on, |
1606 | 60 | Some(filter), |
1607 | 60 | join_type, |
1608 | 60 | None, |
1609 | 60 | PartitionMode::CollectLeft, |
1610 | 60 | null_equals_null, |
1611 | 60 | ) |
1612 | 60 | } |
1613 | | |
1614 | 31 | async fn join_collect( |
1615 | 31 | left: Arc<dyn ExecutionPlan>, |
1616 | 31 | right: Arc<dyn ExecutionPlan>, |
1617 | 31 | on: JoinOn, |
1618 | 31 | join_type: &JoinType, |
1619 | 31 | null_equals_null: bool, |
1620 | 31 | context: Arc<TaskContext>, |
1621 | 31 | ) -> Result<(Vec<String>, Vec<RecordBatch>)> { |
1622 | 31 | let join = join(left, right, on, join_type, null_equals_null)?;
1623 | 31 | let columns_header = columns(&join.schema()); |
1624 | | |
1625 | 31 | let stream = join.execute(0, context)?;
1626 | 31 | let batches = common::collect(stream).await?;
1627 | | |
1628 | 31 | Ok((columns_header, batches)) |
1629 | 31 | } |
1630 | | |
1631 | 15 | async fn partitioned_join_collect( |
1632 | 15 | left: Arc<dyn ExecutionPlan>, |
1633 | 15 | right: Arc<dyn ExecutionPlan>, |
1634 | 15 | on: JoinOn, |
1635 | 15 | join_type: &JoinType, |
1636 | 15 | null_equals_null: bool, |
1637 | 15 | context: Arc<TaskContext>, |
1638 | 15 | ) -> Result<(Vec<String>, Vec<RecordBatch>)> { |
1639 | 15 | join_collect_with_partition_mode( |
1640 | 15 | left, |
1641 | 15 | right, |
1642 | 15 | on, |
1643 | 15 | join_type, |
1644 | 15 | PartitionMode::Partitioned, |
1645 | 15 | null_equals_null, |
1646 | 15 | context, |
1647 | 15 | ) |
1648 | 30 | .await |
1649 | 15 | } |
1650 | | |
1651 | 23 | async fn join_collect_with_partition_mode( |
1652 | 23 | left: Arc<dyn ExecutionPlan>, |
1653 | 23 | right: Arc<dyn ExecutionPlan>, |
1654 | 23 | on: JoinOn, |
1655 | 23 | join_type: &JoinType, |
1656 | 23 | partition_mode: PartitionMode, |
1657 | 23 | null_equals_null: bool, |
1658 | 23 | context: Arc<TaskContext>, |
1659 | 23 | ) -> Result<(Vec<String>, Vec<RecordBatch>)> { |
1660 | 23 | let partition_count = 4; |
1661 | 23 | |
1662 | 23 | let (left_expr, right_expr) = on |
1663 | 23 | .iter() |
1664 | 23 | .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) |
1665 | 23 | .unzip(); |
1666 | | |
1667 | 23 | let left_repartitioned: Arc<dyn ExecutionPlan> = match partition_mode { |
1668 | 8 | PartitionMode::CollectLeft => Arc::new(CoalescePartitionsExec::new(left)), |
1669 | 15 | PartitionMode::Partitioned => Arc::new(RepartitionExec::try_new( |
1670 | 15 | left, |
1671 | 15 | Partitioning::Hash(left_expr, partition_count), |
1672 | 15 | )?),
1673 | | PartitionMode::Auto => { |
1674 | 0 | return internal_err!("Unexpected PartitionMode::Auto in join tests") |
1675 | | } |
1676 | | }; |
1677 | | |
1678 | 23 | let right_repartitioned: Arc<dyn ExecutionPlan> = match partition_mode { |
1679 | | PartitionMode::CollectLeft => { |
1680 | 8 | let partition_column_name = right.schema().field(0).name().clone(); |
1681 | 8 | let partition_expr = vec![Arc::new(Column::new_with_schema( |
1682 | 8 | &partition_column_name, |
1683 | 8 | &right.schema(), |
1684 | 8 | )?) as _];
1685 | 8 | Arc::new(RepartitionExec::try_new( |
1686 | 8 | right, |
1687 | 8 | Partitioning::Hash(partition_expr, partition_count), |
1688 | 8 | )?) as _
1689 | | } |
1690 | 15 | PartitionMode::Partitioned => Arc::new(RepartitionExec::try_new( |
1691 | 15 | right, |
1692 | 15 | Partitioning::Hash(right_expr, partition_count), |
1693 | 15 | )?),
1694 | | PartitionMode::Auto => { |
1695 | 0 | return internal_err!("Unexpected PartitionMode::Auto in join tests") |
1696 | | } |
1697 | | }; |
1698 | | |
1699 | 23 | let join = HashJoinExec::try_new( |
1700 | 23 | left_repartitioned, |
1701 | 23 | right_repartitioned, |
1702 | 23 | on, |
1703 | 23 | None, |
1704 | 23 | join_type, |
1705 | 23 | None, |
1706 | 23 | partition_mode, |
1707 | 23 | null_equals_null, |
1708 | 23 | )?;
1709 | | |
1710 | 23 | let columns = columns(&join.schema()); |
1711 | 23 | |
1712 | 23 | let mut batches = vec![]; |
1713 | 92 | for i in 0..partition_count {
1714 | 92 | let stream = join.execute(i, Arc::clone(&context))?;
1715 | 92 | let more_batches = common::collect(stream).await?;
1716 | 92 | batches.extend( |
1717 | 92 | more_batches |
1718 | 92 | .into_iter() |
1719 | 92 | .filter(|b| b.num_rows() > 0)
1720 | 92 | .collect::<Vec<_>>(), |
1721 | 92 | ); |
1722 | 92 | } |
1723 | | |
1724 | 23 | Ok((columns, batches)) |
1725 | 23 | } |
1726 | | |
1727 | | #[apply(batch_sizes)] |
1728 | | #[tokio::test] |
1729 | | async fn join_inner_one(batch_size: usize) -> Result<()> { |
1730 | | let task_ctx = prepare_task_ctx(batch_size); |
1731 | | let left = build_table( |
1732 | | ("a1", &vec![1, 2, 3]), |
1733 | | ("b1", &vec![4, 5, 5]), // this has a repetition |
1734 | | ("c1", &vec![7, 8, 9]), |
1735 | | ); |
1736 | | let right = build_table( |
1737 | | ("a2", &vec![10, 20, 30]), |
1738 | | ("b1", &vec![4, 5, 6]), |
1739 | | ("c2", &vec![70, 80, 90]), |
1740 | | ); |
1741 | | |
1742 | | let on = vec![( |
1743 | | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, |
1744 | | Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, |
1745 | | )]; |
1746 | | |
1747 | | let (columns, batches) = join_collect( |
1748 | | Arc::clone(&left), |
1749 | | Arc::clone(&right), |
1750 | | on.clone(), |
1751 | | &JoinType::Inner, |
1752 | | false, |
1753 | | task_ctx, |
1754 | | ) |
1755 | | .await?; |
1756 | | |
1757 | | assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); |
1758 | | |
1759 | | let expected = [ |
1760 | | "+----+----+----+----+----+----+", |
1761 | | "| a1 | b1 | c1 | a2 | b1 | c2 |", |
1762 | | "+----+----+----+----+----+----+", |
1763 | | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
1764 | | "| 2 | 5 | 8 | 20 | 5 | 80 |", |
1765 | | "| 3 | 5 | 9 | 20 | 5 | 80 |", |
1766 | | "+----+----+----+----+----+----+", |
1767 | | ]; |
1768 | | |
1769 | | // Inner join output is expected to preserve both inputs' order
1770 | | assert_batches_eq!(expected, &batches); |
1771 | | |
1772 | | Ok(()) |
1773 | | } |
1774 | | |
1775 | | #[apply(batch_sizes)] |
1776 | | #[tokio::test] |
1777 | | async fn partitioned_join_inner_one(batch_size: usize) -> Result<()> { |
1778 | | let task_ctx = prepare_task_ctx(batch_size); |
1779 | | let left = build_table( |
1780 | | ("a1", &vec![1, 2, 3]), |
1781 | | ("b1", &vec![4, 5, 5]), // this has a repetition |
1782 | | ("c1", &vec![7, 8, 9]), |
1783 | | ); |
1784 | | let right = build_table( |
1785 | | ("a2", &vec![10, 20, 30]), |
1786 | | ("b1", &vec![4, 5, 6]), |
1787 | | ("c2", &vec![70, 80, 90]), |
1788 | | ); |
1789 | | let on = vec![( |
1790 | | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, |
1791 | | Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, |
1792 | | )]; |
1793 | | |
1794 | | let (columns, batches) = partitioned_join_collect( |
1795 | | Arc::clone(&left), |
1796 | | Arc::clone(&right), |
1797 | | on.clone(), |
1798 | | &JoinType::Inner, |
1799 | | false, |
1800 | | task_ctx, |
1801 | | ) |
1802 | | .await?; |
1803 | | |
1804 | | assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); |
1805 | | |
1806 | | let expected = [ |
1807 | | "+----+----+----+----+----+----+", |
1808 | | "| a1 | b1 | c1 | a2 | b1 | c2 |", |
1809 | | "+----+----+----+----+----+----+", |
1810 | | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
1811 | | "| 2 | 5 | 8 | 20 | 5 | 80 |", |
1812 | | "| 3 | 5 | 9 | 20 | 5 | 80 |", |
1813 | | "+----+----+----+----+----+----+", |
1814 | | ]; |
1815 | | assert_batches_sorted_eq!(expected, &batches); |
1816 | | |
1817 | | Ok(()) |
1818 | | } |
1819 | | |
1820 | | #[tokio::test] |
1821 | 1 | async fn join_inner_one_no_shared_column_names() -> Result<()> { |
1822 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
1823 | 1 | let left = build_table( |
1824 | 1 | ("a1", &vec![1, 2, 3]), |
1825 | 1 | ("b1", &vec![4, 5, 5]), // this has a repetition |
1826 | 1 | ("c1", &vec![7, 8, 9]), |
1827 | 1 | ); |
1828 | 1 | let right = build_table( |
1829 | 1 | ("a2", &vec![10, 20, 30]), |
1830 | 1 | ("b2", &vec![4, 5, 6]), |
1831 | 1 | ("c2", &vec![70, 80, 90]), |
1832 | 1 | ); |
1833 | 1 | let on = vec![( |
1834 | 1 | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1835 | 1 | Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
1836 | 1 | )]; |
1837 | 1 | |
1838 | 1 | let (columns, batches) = |
1839 | 1 | join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?;
1840 | 1 | |
1841 | 1 | assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); |
1842 | 1 | |
1843 | 1 | let expected = [ |
1844 | 1 | "+----+----+----+----+----+----+", |
1845 | 1 | "| a1 | b1 | c1 | a2 | b2 | c2 |", |
1846 | 1 | "+----+----+----+----+----+----+", |
1847 | 1 | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
1848 | 1 | "| 2 | 5 | 8 | 20 | 5 | 80 |", |
1849 | 1 | "| 3 | 5 | 9 | 20 | 5 | 80 |", |
1850 | 1 | "+----+----+----+----+----+----+", |
1851 | 1 | ]; |
1852 | 1 | |
1853 | 1 | // Inner join output is expected to preserve both inputs' order
1854 | 1 | assert_batches_eq!(expected, &batches); |
1855 | 1 | |
1856 | 1 | Ok(()) |
1857 | 1 | } |
1858 | | |
1859 | | #[tokio::test] |
1860 | 1 | async fn join_inner_one_randomly_ordered() -> Result<()> { |
1861 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
1862 | 1 | let left = build_table( |
1863 | 1 | ("a1", &vec![0, 3, 2, 1]), |
1864 | 1 | ("b1", &vec![4, 5, 5, 4]), |
1865 | 1 | ("c1", &vec![6, 9, 8, 7]), |
1866 | 1 | ); |
1867 | 1 | let right = build_table( |
1868 | 1 | ("a2", &vec![20, 30, 10]), |
1869 | 1 | ("b2", &vec![5, 6, 4]), |
1870 | 1 | ("c2", &vec![80, 90, 70]), |
1871 | 1 | ); |
1872 | 1 | let on = vec![( |
1873 | 1 | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1874 | 1 | Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
1875 | 1 | )]; |
1876 | 1 | |
1877 | 1 | let (columns, batches) = |
1878 | 1 | join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?;
1879 | 1 | |
1880 | 1 | assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); |
1881 | 1 | |
1882 | 1 | let expected = [ |
1883 | 1 | "+----+----+----+----+----+----+", |
1884 | 1 | "| a1 | b1 | c1 | a2 | b2 | c2 |", |
1885 | 1 | "+----+----+----+----+----+----+", |
1886 | 1 | "| 3 | 5 | 9 | 20 | 5 | 80 |", |
1887 | 1 | "| 2 | 5 | 8 | 20 | 5 | 80 |", |
1888 | 1 | "| 0 | 4 | 6 | 10 | 4 | 70 |", |
1889 | 1 | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
1890 | 1 | "+----+----+----+----+----+----+", |
1891 | 1 | ]; |
1892 | 1 | |
1893 | 1 | // Inner join output is expected to preserve both inputs' order
1894 | 1 | assert_batches_eq!(expected, &batches); |
1895 | 1 | |
1896 | 1 | Ok(()) |
1897 | 1 | } |
1898 | | |
1899 | | #[apply(batch_sizes)] |
1900 | | #[tokio::test] |
1901 | | async fn join_inner_two(batch_size: usize) -> Result<()> { |
1902 | | let task_ctx = prepare_task_ctx(batch_size); |
1903 | | let left = build_table( |
1904 | | ("a1", &vec![1, 2, 2]), |
1905 | | ("b2", &vec![1, 2, 2]), |
1906 | | ("c1", &vec![7, 8, 9]), |
1907 | | ); |
1908 | | let right = build_table( |
1909 | | ("a1", &vec![1, 2, 3]), |
1910 | | ("b2", &vec![1, 2, 2]), |
1911 | | ("c2", &vec![70, 80, 90]), |
1912 | | ); |
1913 | | let on = vec![ |
1914 | | ( |
1915 | | Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, |
1916 | | Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, |
1917 | | ), |
1918 | | ( |
1919 | | Arc::new(Column::new_with_schema("b2", &left.schema())?) as _, |
1920 | | Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, |
1921 | | ), |
1922 | | ]; |
1923 | | |
1924 | | let (columns, batches) = |
1925 | | join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; |
1926 | | |
1927 | | assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); |
1928 | | |
1929 | | let expected_batch_count = if cfg!(not(feature = "force_hash_collisions")) { |
1930 | | // Expected number of hash table matches = 3 |
1931 | | // in case batch_size is 1 -- an additional empty batch for the remaining non-joined probe row (a1 = 3, b2 = 2)
1932 | | let mut expected_batch_count = div_ceil(3, batch_size); |
1933 | | if batch_size == 1 { |
1934 | | expected_batch_count += 1; |
1935 | | } |
1936 | | expected_batch_count |
1937 | | } else { |
1938 | | // With hash collisions enabled, all records will match each other |
1939 | | // and be filtered out later.
1940 | | div_ceil(9, batch_size) |
1941 | | }; |
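 | | // [editor's example] With 3 matched rows: batch_size = 2 gives
 | | // div_ceil(3, 2) = 2 batches, while batch_size = 1 gives
 | | // div_ceil(3, 1) + 1 = 4 batches (the extra one being the empty batch
 | | // noted above).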
1942 | | |
1943 | | assert_eq!(batches.len(), expected_batch_count); |
1944 | | |
1945 | | let expected = [ |
1946 | | "+----+----+----+----+----+----+", |
1947 | | "| a1 | b2 | c1 | a1 | b2 | c2 |", |
1948 | | "+----+----+----+----+----+----+", |
1949 | | "| 1 | 1 | 7 | 1 | 1 | 70 |", |
1950 | | "| 2 | 2 | 8 | 2 | 2 | 80 |", |
1951 | | "| 2 | 2 | 9 | 2 | 2 | 80 |", |
1952 | | "+----+----+----+----+----+----+", |
1953 | | ]; |
1954 | | |
1955 | | // Inner join output is expected to preserve both inputs' order
1956 | | assert_batches_eq!(expected, &batches); |
1957 | | |
1958 | | Ok(()) |
1959 | | } |
1960 | | |
1961 | | /// Test where the left has 2 parts, the right has 1 part => 1 part
1962 | | #[apply(batch_sizes)] |
1963 | | #[tokio::test] |
1964 | | async fn join_inner_one_two_parts_left(batch_size: usize) -> Result<()> { |
1965 | | let task_ctx = prepare_task_ctx(batch_size); |
1966 | | let batch1 = build_table_i32( |
1967 | | ("a1", &vec![1, 2]), |
1968 | | ("b2", &vec![1, 2]), |
1969 | | ("c1", &vec![7, 8]), |
1970 | | ); |
1971 | | let batch2 = |
1972 | | build_table_i32(("a1", &vec![2]), ("b2", &vec![2]), ("c1", &vec![9])); |
1973 | | let schema = batch1.schema(); |
1974 | | let left = Arc::new( |
1975 | | MemoryExec::try_new(&[vec![batch1], vec![batch2]], schema, None).unwrap(), |
1976 | | ); |
1977 | | |
1978 | | let right = build_table( |
1979 | | ("a1", &vec![1, 2, 3]), |
1980 | | ("b2", &vec![1, 2, 2]), |
1981 | | ("c2", &vec![70, 80, 90]), |
1982 | | ); |
1983 | | let on = vec![ |
1984 | | ( |
1985 | | Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, |
1986 | | Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, |
1987 | | ), |
1988 | | ( |
1989 | | Arc::new(Column::new_with_schema("b2", &left.schema())?) as _, |
1990 | | Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, |
1991 | | ), |
1992 | | ]; |
1993 | | |
1994 | | let (columns, batches) = |
1995 | | join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; |
1996 | | |
1997 | | assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); |
1998 | | |
1999 | | let expected_batch_count = if cfg!(not(feature = "force_hash_collisions")) { |
2000 | | // Expected number of hash table matches = 3 |
2001 | | // in case batch_size is 1 -- an additional empty batch for the remaining non-joined probe row (a1 = 3, b2 = 2)
2002 | | let mut expected_batch_count = div_ceil(3, batch_size); |
2003 | | if batch_size == 1 { |
2004 | | expected_batch_count += 1; |
2005 | | } |
2006 | | expected_batch_count |
2007 | | } else { |
2008 | | // With hash collisions enabled, all records will match each other |
2009 | | // and be filtered out later.
2010 | | div_ceil(9, batch_size) |
2011 | | }; |
2012 | | |
2013 | | assert_eq!(batches.len(), expected_batch_count); |
2014 | | |
2015 | | let expected = [ |
2016 | | "+----+----+----+----+----+----+", |
2017 | | "| a1 | b2 | c1 | a1 | b2 | c2 |", |
2018 | | "+----+----+----+----+----+----+", |
2019 | | "| 1 | 1 | 7 | 1 | 1 | 70 |", |
2020 | | "| 2 | 2 | 8 | 2 | 2 | 80 |", |
2021 | | "| 2 | 2 | 9 | 2 | 2 | 80 |", |
2022 | | "+----+----+----+----+----+----+", |
2023 | | ]; |
2024 | | |
2025 | | // Inner join output is expected to preserve both inputs' order
2026 | | assert_batches_eq!(expected, &batches); |
2027 | | |
2028 | | Ok(()) |
2029 | | } |
2030 | | |
2031 | | #[tokio::test] |
2032 | 1 | async fn join_inner_one_two_parts_left_randomly_ordered() -> Result<()> { |
2033 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
2034 | 1 | let batch1 = build_table_i32( |
2035 | 1 | ("a1", &vec![0, 3]), |
2036 | 1 | ("b1", &vec![4, 5]), |
2037 | 1 | ("c1", &vec![6, 9]), |
2038 | 1 | ); |
2039 | 1 | let batch2 = build_table_i32( |
2040 | 1 | ("a1", &vec![2, 1]), |
2041 | 1 | ("b1", &vec![5, 4]), |
2042 | 1 | ("c1", &vec![8, 7]), |
2043 | 1 | ); |
2044 | 1 | let schema = batch1.schema(); |
2045 | 1 | |
2046 | 1 | let left = Arc::new( |
2047 | 1 | MemoryExec::try_new(&[vec![batch1], vec![batch2]], schema, None).unwrap(), |
2048 | 1 | ); |
2049 | 1 | let right = build_table( |
2050 | 1 | ("a2", &vec![20, 30, 10]), |
2051 | 1 | ("b2", &vec![5, 6, 4]), |
2052 | 1 | ("c2", &vec![80, 90, 70]), |
2053 | 1 | ); |
2054 | 1 | let on = vec![( |
2055 | 1 | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2056 | 1 | Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2057 | 1 | )]; |
2058 | 1 | |
2059 | 1 | let (columns, batches) = |
2060 | 1 | join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?;
2061 | 1 | |
2062 | 1 | assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); |
2063 | 1 | |
2064 | 1 | let expected = [ |
2065 | 1 | "+----+----+----+----+----+----+", |
2066 | 1 | "| a1 | b1 | c1 | a2 | b2 | c2 |", |
2067 | 1 | "+----+----+----+----+----+----+", |
2068 | 1 | "| 3 | 5 | 9 | 20 | 5 | 80 |", |
2069 | 1 | "| 2 | 5 | 8 | 20 | 5 | 80 |", |
2070 | 1 | "| 0 | 4 | 6 | 10 | 4 | 70 |", |
2071 | 1 | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
2072 | 1 | "+----+----+----+----+----+----+", |
2073 | 1 | ]; |
2074 | 1 | |
2075 | 1 | // Inner join output is expected to preserve both inputs' order
2076 | 1 | assert_batches_eq!(expected, &batches); |
2077 | 1 | |
2078 | 1 | Ok(()) |
2079 | 1 | } |
2080 | | |
2081 | | /// Test where the left has 1 part, the right has 2 parts => 2 parts |
2082 | | #[apply(batch_sizes)] |
2083 | | #[tokio::test] |
2084 | | async fn join_inner_one_two_parts_right(batch_size: usize) -> Result<()> { |
2085 | | let task_ctx = prepare_task_ctx(batch_size); |
2086 | | let left = build_table( |
2087 | | ("a1", &vec![1, 2, 3]), |
2088 | | ("b1", &vec![4, 5, 5]), // this has a repetition |
2089 | | ("c1", &vec![7, 8, 9]), |
2090 | | ); |
2091 | | |
2092 | | let batch1 = build_table_i32( |
2093 | | ("a2", &vec![10, 20]), |
2094 | | ("b1", &vec![4, 6]), |
2095 | | ("c2", &vec![70, 80]), |
2096 | | ); |
2097 | | let batch2 = |
2098 | | build_table_i32(("a2", &vec![30]), ("b1", &vec![5]), ("c2", &vec![90])); |
2099 | | let schema = batch1.schema(); |
2100 | | let right = Arc::new( |
2101 | | MemoryExec::try_new(&[vec![batch1], vec![batch2]], schema, None).unwrap(), |
2102 | | ); |
2103 | | |
2104 | | let on = vec![( |
2105 | | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, |
2106 | | Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, |
2107 | | )]; |
2108 | | |
2109 | | let join = join(left, right, on, &JoinType::Inner, false)?; |
2110 | | |
2111 | | let columns = columns(&join.schema()); |
2112 | | assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); |
2113 | | |
2114 | | // first part |
2115 | | let stream = join.execute(0, Arc::clone(&task_ctx))?; |
2116 | | let batches = common::collect(stream).await?; |
2117 | | |
2118 | | let expected_batch_count = if cfg!(not(feature = "force_hash_collisions")) { |
2119 | | // Expected number of hash table matches for first right batch = 1 |
2120 | | // and an additional empty batch for the non-joined row (20, 6, 80)
2121 | | let mut expected_batch_count = div_ceil(1, batch_size); |
2122 | | if batch_size == 1 { |
2123 | | expected_batch_count += 1; |
2124 | | } |
2125 | | expected_batch_count |
2126 | | } else { |
2127 | | // With hash collisions enabled, all records will match each other |
2128 | | // and be filtered out later.
2129 | | div_ceil(6, batch_size) |
2130 | | }; |
2131 | | assert_eq!(batches.len(), expected_batch_count); |
2132 | | |
2133 | | let expected = [ |
2134 | | "+----+----+----+----+----+----+", |
2135 | | "| a1 | b1 | c1 | a2 | b1 | c2 |", |
2136 | | "+----+----+----+----+----+----+", |
2137 | | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
2138 | | "+----+----+----+----+----+----+", |
2139 | | ]; |
2140 | | |
2141 | | // Inner join output is expected to preserve both inputs' order
2142 | | assert_batches_eq!(expected, &batches); |
2143 | | |
2144 | | // second part |
2145 | | let stream = join.execute(1, Arc::clone(&task_ctx))?; |
2146 | | let batches = common::collect(stream).await?; |
2147 | | |
2148 | | let expected_batch_count = if cfg!(not(feature = "force_hash_collisions")) { |
2149 | | // Expected number of hash table matches for second right batch = 2 |
2150 | | div_ceil(2, batch_size) |
2151 | | } else { |
2152 | | // With hash collisions enabled, all records will match each other |
2153 | | // and be filtered out later.
2154 | | div_ceil(3, batch_size) |
2155 | | }; |
2156 | | assert_eq!(batches.len(), expected_batch_count); |
2157 | | |
2158 | | let expected = [ |
2159 | | "+----+----+----+----+----+----+", |
2160 | | "| a1 | b1 | c1 | a2 | b1 | c2 |", |
2161 | | "+----+----+----+----+----+----+", |
2162 | | "| 2 | 5 | 8 | 30 | 5 | 90 |", |
2163 | | "| 3 | 5 | 9 | 30 | 5 | 90 |", |
2164 | | "+----+----+----+----+----+----+", |
2165 | | ]; |
2166 | | |
2167 | | // Inner join output is expected to preserve both inputs' order
2168 | | assert_batches_eq!(expected, &batches); |
2169 | | |
2170 | | Ok(()) |
2171 | | } |
2172 | | |
2173 | 10 | fn build_table_two_batches( |
2174 | 10 | a: (&str, &Vec<i32>), |
2175 | 10 | b: (&str, &Vec<i32>), |
2176 | 10 | c: (&str, &Vec<i32>), |
2177 | 10 | ) -> Arc<dyn ExecutionPlan> { |
2178 | 10 | let batch = build_table_i32(a, b, c); |
2179 | 10 | let schema = batch.schema(); |
2180 | 10 | Arc::new( |
2181 | 10 | MemoryExec::try_new(&[vec![batch.clone(), batch]], schema, None).unwrap(), |
2182 | 10 | ) |
2183 | 10 | } |
2184 | | |
2185 | | #[apply(batch_sizes)] |
2186 | | #[tokio::test] |
2187 | | async fn join_left_multi_batch(batch_size: usize) { |
2188 | | let task_ctx = prepare_task_ctx(batch_size); |
2189 | | let left = build_table( |
2190 | | ("a1", &vec![1, 2, 3]), |
2191 | | ("b1", &vec![4, 5, 7]), // 7 does not exist on the right |
2192 | | ("c1", &vec![7, 8, 9]), |
2193 | | ); |
2194 | | let right = build_table_two_batches( |
2195 | | ("a2", &vec![10, 20, 30]), |
2196 | | ("b1", &vec![4, 5, 6]), |
2197 | | ("c2", &vec![70, 80, 90]), |
2198 | | ); |
2199 | | let on = vec![( |
2200 | | Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _, |
2201 | | Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) as _, |
2202 | | )]; |
2203 | | |
2204 | | let join = join(left, right, on, &JoinType::Left, false).unwrap(); |
2205 | | |
2206 | | let columns = columns(&join.schema()); |
2207 | | assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); |
2208 | | |
2209 | | let stream = join.execute(0, task_ctx).unwrap(); |
2210 | | let batches = common::collect(stream).await.unwrap(); |
2211 | | |
2212 | | let expected = [ |
2213 | | "+----+----+----+----+----+----+", |
2214 | | "| a1 | b1 | c1 | a2 | b1 | c2 |", |
2215 | | "+----+----+----+----+----+----+", |
2216 | | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
2217 | | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
2218 | | "| 2 | 5 | 8 | 20 | 5 | 80 |", |
2219 | | "| 2 | 5 | 8 | 20 | 5 | 80 |", |
2220 | | "| 3 | 7 | 9 | | | |", |
2221 | | "+----+----+----+----+----+----+", |
2222 | | ]; |
2223 | | |
2224 | | assert_batches_sorted_eq!(expected, &batches); |
2225 | | } |
2226 | | |
2227 | | #[apply(batch_sizes)] |
2228 | | #[tokio::test] |
2229 | | async fn join_full_multi_batch(batch_size: usize) { |
2230 | | let task_ctx = prepare_task_ctx(batch_size); |
2231 | | let left = build_table( |
2232 | | ("a1", &vec![1, 2, 3]), |
2233 | | ("b1", &vec![4, 5, 7]), // 7 does not exist on the right |
2234 | | ("c1", &vec![7, 8, 9]), |
2235 | | ); |
2236 | | // create two identical batches for the right side |
2237 | | let right = build_table_two_batches( |
2238 | | ("a2", &vec![10, 20, 30]), |
2239 | | ("b2", &vec![4, 5, 6]), |
2240 | | ("c2", &vec![70, 80, 90]), |
2241 | | ); |
2242 | | let on = vec![( |
2243 | | Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _, |
2244 | | Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, |
2245 | | )]; |
2246 | | |
2247 | | let join = join(left, right, on, &JoinType::Full, false).unwrap(); |
2248 | | |
2249 | | let columns = columns(&join.schema()); |
2250 | | assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); |
2251 | | |
2252 | | let stream = join.execute(0, task_ctx).unwrap(); |
2253 | | let batches = common::collect(stream).await.unwrap(); |
2254 | | |
2255 | | let expected = [ |
2256 | | "+----+----+----+----+----+----+", |
2257 | | "| a1 | b1 | c1 | a2 | b2 | c2 |", |
2258 | | "+----+----+----+----+----+----+", |
2259 | | "| | | | 30 | 6 | 90 |", |
2260 | | "| | | | 30 | 6 | 90 |", |
2261 | | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
2262 | | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
2263 | | "| 2 | 5 | 8 | 20 | 5 | 80 |", |
2264 | | "| 2 | 5 | 8 | 20 | 5 | 80 |", |
2265 | | "| 3 | 7 | 9 | | | |", |
2266 | | "+----+----+----+----+----+----+", |
2267 | | ]; |
2268 | | |
2269 | | assert_batches_sorted_eq!(expected, &batches); |
2270 | | } |
2271 | | |
2272 | | #[apply(batch_sizes)] |
2273 | | #[tokio::test] |
2274 | | async fn join_left_empty_right(batch_size: usize) { |
2275 | | let task_ctx = prepare_task_ctx(batch_size); |
2276 | | let left = build_table( |
2277 | | ("a1", &vec![1, 2, 3]), |
2278 | | ("b1", &vec![4, 5, 7]), |
2279 | | ("c1", &vec![7, 8, 9]), |
2280 | | ); |
2281 | | let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2", &vec![])); |
2282 | | let on = vec![( |
2283 | | Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _, |
2284 | | Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) as _, |
2285 | | )]; |
2286 | | let schema = right.schema(); |
2287 | | let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap()); |
2288 | | let join = join(left, right, on, &JoinType::Left, false).unwrap(); |
2289 | | |
2290 | | let columns = columns(&join.schema()); |
2291 | | assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); |
2292 | | |
2293 | | let stream = join.execute(0, task_ctx).unwrap(); |
2294 | | let batches = common::collect(stream).await.unwrap(); |
2295 | | |
2296 | | let expected = [ |
2297 | | "+----+----+----+----+----+----+", |
2298 | | "| a1 | b1 | c1 | a2 | b1 | c2 |", |
2299 | | "+----+----+----+----+----+----+", |
2300 | | "| 1 | 4 | 7 | | | |", |
2301 | | "| 2 | 5 | 8 | | | |", |
2302 | | "| 3 | 7 | 9 | | | |", |
2303 | | "+----+----+----+----+----+----+", |
2304 | | ]; |
2305 | | |
2306 | | assert_batches_sorted_eq!(expected, &batches); |
2307 | | } |
2308 | | |
2309 | | #[apply(batch_sizes)] |
2310 | | #[tokio::test] |
2311 | | async fn join_full_empty_right(batch_size: usize) { |
2312 | | let task_ctx = prepare_task_ctx(batch_size); |
2313 | | let left = build_table( |
2314 | | ("a1", &vec![1, 2, 3]), |
2315 | | ("b1", &vec![4, 5, 7]), |
2316 | | ("c1", &vec![7, 8, 9]), |
2317 | | ); |
2318 | | let right = build_table_i32(("a2", &vec![]), ("b2", &vec![]), ("c2", &vec![])); |
2319 | | let on = vec![( |
2320 | | Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _, |
2321 | | Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, |
2322 | | )]; |
2323 | | let schema = right.schema(); |
2324 | | let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap()); |
2325 | | let join = join(left, right, on, &JoinType::Full, false).unwrap(); |
2326 | | |
2327 | | let columns = columns(&join.schema()); |
2328 | | assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); |
2329 | | |
2330 | | let stream = join.execute(0, task_ctx).unwrap(); |
2331 | | let batches = common::collect(stream).await.unwrap(); |
2332 | | |
2333 | | let expected = [ |
2334 | | "+----+----+----+----+----+----+", |
2335 | | "| a1 | b1 | c1 | a2 | b2 | c2 |", |
2336 | | "+----+----+----+----+----+----+", |
2337 | | "| 1 | 4 | 7 | | | |", |
2338 | | "| 2 | 5 | 8 | | | |", |
2339 | | "| 3 | 7 | 9 | | | |", |
2340 | | "+----+----+----+----+----+----+", |
2341 | | ]; |
2342 | | |
2343 | | assert_batches_sorted_eq!(expected, &batches); |
2344 | | } |
2345 | | |
2346 | | #[apply(batch_sizes)] |
2347 | | #[tokio::test] |
2348 | | async fn join_left_one(batch_size: usize) -> Result<()> { |
2349 | | let task_ctx = prepare_task_ctx(batch_size); |
2350 | | let left = build_table( |
2351 | | ("a1", &vec![1, 2, 3]), |
2352 | | ("b1", &vec![4, 5, 7]), // 7 does not exist on the right |
2353 | | ("c1", &vec![7, 8, 9]), |
2354 | | ); |
2355 | | let right = build_table( |
2356 | | ("a2", &vec![10, 20, 30]), |
2357 | | ("b1", &vec![4, 5, 6]), |
2358 | | ("c2", &vec![70, 80, 90]), |
2359 | | ); |
2360 | | let on = vec![( |
2361 | | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, |
2362 | | Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, |
2363 | | )]; |
2364 | | |
2365 | | let (columns, batches) = join_collect( |
2366 | | Arc::clone(&left), |
2367 | | Arc::clone(&right), |
2368 | | on.clone(), |
2369 | | &JoinType::Left, |
2370 | | false, |
2371 | | task_ctx, |
2372 | | ) |
2373 | | .await?; |
2374 | | assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); |
2375 | | |
2376 | | let expected = [ |
2377 | | "+----+----+----+----+----+----+", |
2378 | | "| a1 | b1 | c1 | a2 | b1 | c2 |", |
2379 | | "+----+----+----+----+----+----+", |
2380 | | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
2381 | | "| 2 | 5 | 8 | 20 | 5 | 80 |", |
2382 | | "| 3 | 7 | 9 | | | |", |
2383 | | "+----+----+----+----+----+----+", |
2384 | | ]; |
2385 | | assert_batches_sorted_eq!(expected, &batches); |
2386 | | |
2387 | | Ok(()) |
2388 | | } |
2389 | | |
2390 | | #[apply(batch_sizes)] |
2391 | | #[tokio::test] |
2392 | | async fn partitioned_join_left_one(batch_size: usize) -> Result<()> { |
2393 | | let task_ctx = prepare_task_ctx(batch_size); |
2394 | | let left = build_table( |
2395 | | ("a1", &vec![1, 2, 3]), |
2396 | | ("b1", &vec![4, 5, 7]), // 7 does not exist on the right |
2397 | | ("c1", &vec![7, 8, 9]), |
2398 | | ); |
2399 | | let right = build_table( |
2400 | | ("a2", &vec![10, 20, 30]), |
2401 | | ("b1", &vec![4, 5, 6]), |
2402 | | ("c2", &vec![70, 80, 90]), |
2403 | | ); |
2404 | | let on = vec![( |
2405 | | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, |
2406 | | Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, |
2407 | | )]; |
2408 | | |
2409 | | let (columns, batches) = partitioned_join_collect( |
2410 | | Arc::clone(&left), |
2411 | | Arc::clone(&right), |
2412 | | on.clone(), |
2413 | | &JoinType::Left, |
2414 | | false, |
2415 | | task_ctx, |
2416 | | ) |
2417 | | .await?; |
2418 | | assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); |
2419 | | |
2420 | | let expected = [ |
2421 | | "+----+----+----+----+----+----+", |
2422 | | "| a1 | b1 | c1 | a2 | b1 | c2 |", |
2423 | | "+----+----+----+----+----+----+", |
2424 | | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
2425 | | "| 2 | 5 | 8 | 20 | 5 | 80 |", |
2426 | | "| 3 | 7 | 9 | | | |", |
2427 | | "+----+----+----+----+----+----+", |
2428 | | ]; |
2429 | | assert_batches_sorted_eq!(expected, &batches); |
2430 | | |
2431 | | Ok(()) |
2432 | | } |
2433 | | |
2434 | 40 | fn build_semi_anti_left_table() -> Arc<dyn ExecutionPlan> { |
2435 | 40 | // only rows with b1 = 8 or b1 = 10
2436 | 40 | // match the right-side table
2437 | 40 | build_table( |
2438 | 40 | ("a1", &vec![1, 3, 5, 7, 9, 11, 13]), |
2439 | 40 | ("b1", &vec![1, 3, 5, 7, 8, 8, 10]), |
2440 | 40 | ("c1", &vec![10, 30, 50, 70, 90, 110, 130]), |
2441 | 40 | ) |
2442 | 40 | } |
2443 | | |
2444 | 40 | fn build_semi_anti_right_table() -> Arc<dyn ExecutionPlan> { |
2445 | 40 | // only rows with b2 = 8 or b2 = 10
2446 | 40 | // match the left-side table
2447 | 40 | build_table( |
2448 | 40 | ("a2", &vec![8, 12, 6, 2, 10, 4]), |
2449 | 40 | ("b2", &vec![8, 10, 6, 2, 10, 4]), |
2450 | 40 | ("c2", &vec![20, 40, 60, 80, 100, 120]), |
2451 | 40 | ) |
2452 | 40 | } |
2453 | | |
2454 | | #[apply(batch_sizes)] |
2455 | | #[tokio::test] |
2456 | | async fn join_left_semi(batch_size: usize) -> Result<()> { |
2457 | | let task_ctx = prepare_task_ctx(batch_size); |
2458 | | let left = build_semi_anti_left_table(); |
2459 | | let right = build_semi_anti_right_table(); |
2460 | | // left_table left semi join right_table on left_table.b1 = right_table.b2 |
2461 | | let on = vec![( |
2462 | | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, |
2463 | | Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, |
2464 | | )]; |
2465 | | |
2466 | | let join = join(left, right, on, &JoinType::LeftSemi, false)?; |
2467 | | |
2468 | | let columns = columns(&join.schema()); |
2469 | | assert_eq!(columns, vec!["a1", "b1", "c1"]); |
2470 | | |
2471 | | let stream = join.execute(0, task_ctx)?; |
2472 | | let batches = common::collect(stream).await?; |
2473 | | |
2474 | | // ignore the order |
2475 | | let expected = [ |
2476 | | "+----+----+-----+", |
2477 | | "| a1 | b1 | c1 |", |
2478 | | "+----+----+-----+", |
2479 | | "| 11 | 8 | 110 |", |
2480 | | "| 13 | 10 | 130 |", |
2481 | | "| 9 | 8 | 90 |", |
2482 | | "+----+----+-----+", |
2483 | | ]; |
2484 | | assert_batches_sorted_eq!(expected, &batches); |
2485 | | |
2486 | | Ok(()) |
2487 | | } |
2488 | | |
2489 | | #[apply(batch_sizes)] |
2490 | | #[tokio::test] |
2491 | | async fn join_left_semi_with_filter(batch_size: usize) -> Result<()> { |
2492 | | let task_ctx = prepare_task_ctx(batch_size); |
2493 | | let left = build_semi_anti_left_table(); |
2494 | | let right = build_semi_anti_right_table(); |
2495 | | |
2496 | | // left_table left semi join right_table on left_table.b1 = right_table.b2 and right_table.a2 != 10 |
2497 | | let on = vec![( |
2498 | | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, |
2499 | | Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, |
2500 | | )]; |
2501 | | |
2502 | | let column_indices = vec![ColumnIndex { |
2503 | | index: 0, |
2504 | | side: JoinSide::Right, |
2505 | | }]; |
2506 | | let intermediate_schema = |
2507 | | Schema::new(vec![Field::new("x", DataType::Int32, true)]); |
2508 | | |
2509 | | let filter_expression = Arc::new(BinaryExpr::new( |
2510 | | Arc::new(Column::new("x", 0)), |
2511 | | Operator::NotEq, |
2512 | | Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), |
2513 | | )) as Arc<dyn PhysicalExpr>; |
2514 | | |
2515 | | let filter = JoinFilter::new( |
2516 | | filter_expression, |
2517 | | column_indices.clone(), |
2518 | | intermediate_schema.clone(), |
2519 | | ); |
2520 | | |
2521 | | let join = join_with_filter( |
2522 | | Arc::clone(&left), |
2523 | | Arc::clone(&right), |
2524 | | on.clone(), |
2525 | | filter, |
2526 | | &JoinType::LeftSemi, |
2527 | | false, |
2528 | | )?; |
2529 | | |
2530 | | let columns_header = columns(&join.schema()); |
2531 | | assert_eq!(columns_header.clone(), vec!["a1", "b1", "c1"]); |
2532 | | |
2533 | | let stream = join.execute(0, Arc::clone(&task_ctx))?; |
2534 | | let batches = common::collect(stream).await?; |
2535 | | |
2536 | | let expected = [ |
2537 | | "+----+----+-----+", |
2538 | | "| a1 | b1 | c1 |", |
2539 | | "+----+----+-----+", |
2540 | | "| 11 | 8 | 110 |", |
2541 | | "| 13 | 10 | 130 |", |
2542 | | "| 9 | 8 | 90 |", |
2543 | | "+----+----+-----+", |
2544 | | ]; |
2545 | | assert_batches_sorted_eq!(expected, &batches); |
2546 | | |
2547 | | // left_table left semi join right_table on left_table.b1 = right_table.b2 and right_table.a2 > 10 |
2548 | | let filter_expression = Arc::new(BinaryExpr::new( |
2549 | | Arc::new(Column::new("x", 0)), |
2550 | | Operator::Gt, |
2551 | | Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), |
2552 | | )) as Arc<dyn PhysicalExpr>; |
2553 | | let filter = |
2554 | | JoinFilter::new(filter_expression, column_indices, intermediate_schema); |
2555 | | |
2556 | | let join = join_with_filter(left, right, on, filter, &JoinType::LeftSemi, false)?; |
2557 | | |
2558 | | let columns_header = columns(&join.schema()); |
2559 | | assert_eq!(columns_header, vec!["a1", "b1", "c1"]); |
2560 | | |
2561 | | let stream = join.execute(0, task_ctx)?; |
2562 | | let batches = common::collect(stream).await?; |
2563 | | |
2564 | | let expected = [ |
2565 | | "+----+----+-----+", |
2566 | | "| a1 | b1 | c1 |", |
2567 | | "+----+----+-----+", |
2568 | | "| 13 | 10 | 130 |", |
2569 | | "+----+----+-----+", |
2570 | | ]; |
2571 | | assert_batches_sorted_eq!(expected, &batches); |
2572 | | |
2573 | | Ok(()) |
2574 | | } |
2575 | | |
2576 | | #[apply(batch_sizes)] |
2577 | | #[tokio::test] |
2578 | | async fn join_right_semi(batch_size: usize) -> Result<()> { |
2579 | | let task_ctx = prepare_task_ctx(batch_size); |
2580 | | let left = build_semi_anti_left_table(); |
2581 | | let right = build_semi_anti_right_table(); |
2582 | | |
2583 | | // left_table right semi join right_table on left_table.b1 = right_table.b2 |
2584 | | let on = vec![( |
2585 | | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, |
2586 | | Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, |
2587 | | )]; |
2588 | | |
2589 | | let join = join(left, right, on, &JoinType::RightSemi, false)?; |
2590 | | |
2591 | | let columns = columns(&join.schema()); |
2592 | | assert_eq!(columns, vec!["a2", "b2", "c2"]); |
2593 | | |
2594 | | let stream = join.execute(0, task_ctx)?; |
2595 | | let batches = common::collect(stream).await?; |
2596 | | |
2597 | | let expected = [ |
2598 | | "+----+----+-----+", |
2599 | | "| a2 | b2 | c2 |", |
2600 | | "+----+----+-----+", |
2601 | | "| 8 | 8 | 20 |", |
2602 | | "| 12 | 10 | 40 |", |
2603 | | "| 10 | 10 | 100 |", |
2604 | | "+----+----+-----+", |
2605 | | ]; |
2606 | | |
2607 | | // RightSemi join output is expected to preserve right input order |
2608 | | assert_batches_eq!(expected, &batches); |
2609 | | |
2610 | | Ok(()) |
2611 | | } |
2612 | | |
2613 | | #[apply(batch_sizes)] |
2614 | | #[tokio::test] |
2615 | | async fn join_right_semi_with_filter(batch_size: usize) -> Result<()> { |
2616 | | let task_ctx = prepare_task_ctx(batch_size); |
2617 | | let left = build_semi_anti_left_table(); |
2618 | | let right = build_semi_anti_right_table(); |
2619 | | |
2620 | | // left_table right semi join right_table on left_table.b1 = right_table.b2 and left_table.a1 != 9
2621 | | let on = vec![( |
2622 | | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, |
2623 | | Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, |
2624 | | )]; |
2625 | | |
2626 | | let column_indices = vec![ColumnIndex { |
2627 | | index: 0, |
2628 | | side: JoinSide::Left, |
2629 | | }]; |
2630 | | let intermediate_schema = |
2631 | | Schema::new(vec![Field::new("x", DataType::Int32, true)]); |
2632 | | |
2633 | | let filter_expression = Arc::new(BinaryExpr::new( |
2634 | | Arc::new(Column::new("x", 0)), |
2635 | | Operator::NotEq, |
2636 | | Arc::new(Literal::new(ScalarValue::Int32(Some(9)))), |
2637 | | )) as Arc<dyn PhysicalExpr>; |
2638 | | |
2639 | | let filter = JoinFilter::new( |
2640 | | filter_expression, |
2641 | | column_indices.clone(), |
2642 | | intermediate_schema.clone(), |
2643 | | ); |
2644 | | |
2645 | | let join = join_with_filter( |
2646 | | Arc::clone(&left), |
2647 | | Arc::clone(&right), |
2648 | | on.clone(), |
2649 | | filter, |
2650 | | &JoinType::RightSemi, |
2651 | | false, |
2652 | | )?; |
2653 | | |
2654 | | let columns = columns(&join.schema()); |
2655 | | assert_eq!(columns, vec!["a2", "b2", "c2"]); |
2656 | | |
2657 | | let stream = join.execute(0, Arc::clone(&task_ctx))?; |
2658 | | let batches = common::collect(stream).await?; |
2659 | | |
2660 | | let expected = [ |
2661 | | "+----+----+-----+", |
2662 | | "| a2 | b2 | c2 |", |
2663 | | "+----+----+-----+", |
2664 | | "| 8 | 8 | 20 |", |
2665 | | "| 12 | 10 | 40 |", |
2666 | | "| 10 | 10 | 100 |", |
2667 | | "+----+----+-----+", |
2668 | | ]; |
2669 | | |
2670 | | // RightSemi join output is expected to preserve right input order |
2671 | | assert_batches_eq!(expected, &batches); |
2672 | | |
2673 | | // left_table right semi join right_table on left_table.b1 = right_table.b2 and left_table.a1 > 11
2674 | | let filter_expression = Arc::new(BinaryExpr::new( |
2675 | | Arc::new(Column::new("x", 0)), |
2676 | | Operator::Gt, |
2677 | | Arc::new(Literal::new(ScalarValue::Int32(Some(11)))), |
2678 | | )) as Arc<dyn PhysicalExpr>; |
2679 | | |
2680 | | let filter = |
2681 | | JoinFilter::new(filter_expression, column_indices, intermediate_schema); |
2682 | | |
2683 | | let join = |
2684 | | join_with_filter(left, right, on, filter, &JoinType::RightSemi, false)?; |
2685 | | let stream = join.execute(0, task_ctx)?; |
2686 | | let batches = common::collect(stream).await?; |
2687 | | |
2688 | | let expected = [ |
2689 | | "+----+----+-----+", |
2690 | | "| a2 | b2 | c2 |", |
2691 | | "+----+----+-----+", |
2692 | | "| 12 | 10 | 40 |", |
2693 | | "| 10 | 10 | 100 |", |
2694 | | "+----+----+-----+", |
2695 | | ]; |
2696 | | |
2697 | | // RightSemi join output is expected to preserve right input order |
2698 | | assert_batches_eq!(expected, &batches); |
2699 | | |
2700 | | Ok(()) |
2701 | | } |
2702 | | |
2703 | | #[apply(batch_sizes)] |
2704 | | #[tokio::test] |
2705 | | async fn join_left_anti(batch_size: usize) -> Result<()> { |
2706 | | let task_ctx = prepare_task_ctx(batch_size); |
2707 | | let left = build_semi_anti_left_table(); |
2708 | | let right = build_semi_anti_right_table(); |
2709 | | // left_table left anti join right_table on left_table.b1 = right_table.b2 |
2710 | | let on = vec![( |
2711 | | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, |
2712 | | Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, |
2713 | | )]; |
2714 | | |
2715 | | let join = join(left, right, on, &JoinType::LeftAnti, false)?; |
2716 | | |
2717 | | let columns = columns(&join.schema()); |
2718 | | assert_eq!(columns, vec!["a1", "b1", "c1"]); |
2719 | | |
2720 | | let stream = join.execute(0, task_ctx)?; |
2721 | | let batches = common::collect(stream).await?; |
2722 | | |
2723 | | let expected = [ |
2724 | | "+----+----+----+", |
2725 | | "| a1 | b1 | c1 |", |
2726 | | "+----+----+----+", |
2727 | | "| 1 | 1 | 10 |", |
2728 | | "| 3 | 3 | 30 |", |
2729 | | "| 5 | 5 | 50 |", |
2730 | | "| 7 | 7 | 70 |", |
2731 | | "+----+----+----+", |
2732 | | ]; |
2733 | | assert_batches_sorted_eq!(expected, &batches); |
2734 | | Ok(()) |
2735 | | } |
2736 | | |
2737 | | #[apply(batch_sizes)] |
2738 | | #[tokio::test] |
2739 | | async fn join_left_anti_with_filter(batch_size: usize) -> Result<()> { |
2740 | | let task_ctx = prepare_task_ctx(batch_size); |
2741 | | let left = build_semi_anti_left_table(); |
2742 | | let right = build_semi_anti_right_table(); |
2743 | | // left_table left anti join right_table on left_table.b1 = right_table.b2 and right_table.a2!=8 |
2744 | | let on = vec![( |
2745 | | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, |
2746 | | Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, |
2747 | | )]; |
2748 | | |
2749 | | let column_indices = vec![ColumnIndex { |
2750 | | index: 0, |
2751 | | side: JoinSide::Right, |
2752 | | }]; |
2753 | | let intermediate_schema = |
2754 | | Schema::new(vec![Field::new("x", DataType::Int32, true)]); |
2755 | | let filter_expression = Arc::new(BinaryExpr::new( |
2756 | | Arc::new(Column::new("x", 0)), |
2757 | | Operator::NotEq, |
2758 | | Arc::new(Literal::new(ScalarValue::Int32(Some(8)))), |
2759 | | )) as Arc<dyn PhysicalExpr>; |
2760 | | |
2761 | | let filter = JoinFilter::new( |
2762 | | filter_expression, |
2763 | | column_indices.clone(), |
2764 | | intermediate_schema.clone(), |
2765 | | ); |
2766 | | |
2767 | | let join = join_with_filter( |
2768 | | Arc::clone(&left), |
2769 | | Arc::clone(&right), |
2770 | | on.clone(), |
2771 | | filter, |
2772 | | &JoinType::LeftAnti, |
2773 | | false, |
2774 | | )?; |
2775 | | |
2776 | | let columns_header = columns(&join.schema()); |
2777 | | assert_eq!(columns_header, vec!["a1", "b1", "c1"]); |
2778 | | |
2779 | | let stream = join.execute(0, Arc::clone(&task_ctx))?; |
2780 | | let batches = common::collect(stream).await?; |
2781 | | |
2782 | | let expected = [ |
2783 | | "+----+----+-----+", |
2784 | | "| a1 | b1 | c1 |", |
2785 | | "+----+----+-----+", |
2786 | | "| 1 | 1 | 10 |", |
2787 | | "| 11 | 8 | 110 |", |
2788 | | "| 3 | 3 | 30 |", |
2789 | | "| 5 | 5 | 50 |", |
2790 | | "| 7 | 7 | 70 |", |
2791 | | "| 9 | 8 | 90 |", |
2792 | | "+----+----+-----+", |
2793 | | ]; |
2794 | | assert_batches_sorted_eq!(expected, &batches); |
2795 | | |
2796 | | // left_table left anti join right_table on left_table.b1 = right_table.b2 and right_table.a2 != 8
2797 | | let filter_expression = Arc::new(BinaryExpr::new( |
2798 | | Arc::new(Column::new("x", 0)), |
2799 | | Operator::NotEq, |
2800 | | Arc::new(Literal::new(ScalarValue::Int32(Some(8)))), |
2801 | | )) as Arc<dyn PhysicalExpr>; |
2802 | | |
2803 | | let filter = |
2804 | | JoinFilter::new(filter_expression, column_indices, intermediate_schema); |
2805 | | |
2806 | | let join = join_with_filter(left, right, on, filter, &JoinType::LeftAnti, false)?; |
2807 | | |
2808 | | let columns_header = columns(&join.schema()); |
2809 | | assert_eq!(columns_header, vec!["a1", "b1", "c1"]); |
2810 | | |
2811 | | let stream = join.execute(0, task_ctx)?; |
2812 | | let batches = common::collect(stream).await?; |
2813 | | |
2814 | | let expected = [ |
2815 | | "+----+----+-----+", |
2816 | | "| a1 | b1 | c1 |", |
2817 | | "+----+----+-----+", |
2818 | | "| 1 | 1 | 10 |", |
2819 | | "| 11 | 8 | 110 |", |
2820 | | "| 3 | 3 | 30 |", |
2821 | | "| 5 | 5 | 50 |", |
2822 | | "| 7 | 7 | 70 |", |
2823 | | "| 9 | 8 | 90 |", |
2824 | | "+----+----+-----+", |
2825 | | ]; |
2826 | | assert_batches_sorted_eq!(expected, &batches); |
2827 | | |
2828 | | Ok(()) |
2829 | | } |
2830 | | |
2831 | | #[apply(batch_sizes)] |
2832 | | #[tokio::test] |
2833 | | async fn join_right_anti(batch_size: usize) -> Result<()> { |
2834 | | let task_ctx = prepare_task_ctx(batch_size); |
2835 | | let left = build_semi_anti_left_table(); |
2836 | | let right = build_semi_anti_right_table(); |
2837 | | let on = vec![( |
2838 | | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, |
2839 | | Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, |
2840 | | )]; |
2841 | | |
2842 | | let join = join(left, right, on, &JoinType::RightAnti, false)?; |
2843 | | |
2844 | | let columns = columns(&join.schema()); |
2845 | | assert_eq!(columns, vec!["a2", "b2", "c2"]); |
2846 | | |
2847 | | let stream = join.execute(0, task_ctx)?; |
2848 | | let batches = common::collect(stream).await?; |
2849 | | |
2850 | | let expected = [ |
2851 | | "+----+----+-----+", |
2852 | | "| a2 | b2 | c2 |", |
2853 | | "+----+----+-----+", |
2854 | | "| 6 | 6 | 60 |", |
2855 | | "| 2 | 2 | 80 |", |
2856 | | "| 4 | 4 | 120 |", |
2857 | | "+----+----+-----+", |
2858 | | ]; |
2859 | | |
2860 | | // RightAnti join output is expected to preserve right input order |
2861 | | assert_batches_eq!(expected, &batches); |
2862 | | Ok(()) |
2863 | | } |
2864 | | |
2865 | | #[apply(batch_sizes)] |
2866 | | #[tokio::test] |
2867 | | async fn join_right_anti_with_filter(batch_size: usize) -> Result<()> { |
2868 | | let task_ctx = prepare_task_ctx(batch_size); |
2869 | | let left = build_semi_anti_left_table(); |
2870 | | let right = build_semi_anti_right_table(); |
2871 | | // left_table right anti join right_table on left_table.b1 = right_table.b2 and left_table.a1!=13 |
2872 | | let on = vec![( |
2873 | | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, |
2874 | | Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, |
2875 | | )]; |
2876 | | |
2877 | | let column_indices = vec![ColumnIndex { |
2878 | | index: 0, |
2879 | | side: JoinSide::Left, |
2880 | | }]; |
2881 | | let intermediate_schema = |
2882 | | Schema::new(vec![Field::new("x", DataType::Int32, true)]); |
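 | | // Note: each ColumnIndex maps a column of the filter's intermediate
 | | // schema back to a (side, index) in the join inputs -- here "x" is the
 | | // left side's column 0, i.e. a1.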
2883 | | |
2884 | | let filter_expression = Arc::new(BinaryExpr::new( |
2885 | | Arc::new(Column::new("x", 0)), |
2886 | | Operator::NotEq, |
2887 | | Arc::new(Literal::new(ScalarValue::Int32(Some(13)))), |
2888 | | )) as Arc<dyn PhysicalExpr>; |
2889 | | |
2890 | | let filter = JoinFilter::new( |
2891 | | filter_expression, |
2892 | | column_indices, |
2893 | | intermediate_schema.clone(), |
2894 | | ); |
2895 | | |
2896 | | let join = join_with_filter( |
2897 | | Arc::clone(&left), |
2898 | | Arc::clone(&right), |
2899 | | on.clone(), |
2900 | | filter, |
2901 | | &JoinType::RightAnti, |
2902 | | false, |
2903 | | )?; |
2904 | | |
2905 | | let columns_header = columns(&join.schema()); |
2906 | | assert_eq!(columns_header, vec!["a2", "b2", "c2"]); |
2907 | | |
2908 | | let stream = join.execute(0, Arc::clone(&task_ctx))?; |
2909 | | let batches = common::collect(stream).await?; |
2910 | | |
2911 | | let expected = [ |
2912 | | "+----+----+-----+", |
2913 | | "| a2 | b2 | c2 |", |
2914 | | "+----+----+-----+", |
2915 | | "| 12 | 10 | 40 |", |
2916 | | "| 6 | 6 | 60 |", |
2917 | | "| 2 | 2 | 80 |", |
2918 | | "| 10 | 10 | 100 |", |
2919 | | "| 4 | 4 | 120 |", |
2920 | | "+----+----+-----+", |
2921 | | ]; |
2922 | | |
2923 | | // RightAnti join output is expected to preserve right input order |
2924 | | assert_batches_eq!(expected, &batches); |
2925 | | |
2926 | | // left_table right anti join right_table on left_table.b1 = right_table.b2 and right_table.b2!=8 |
2927 | | let column_indices = vec![ColumnIndex { |
2928 | | index: 1, |
2929 | | side: JoinSide::Right, |
2930 | | }]; |
2931 | | let filter_expression = Arc::new(BinaryExpr::new( |
2932 | | Arc::new(Column::new("x", 0)), |
2933 | | Operator::NotEq, |
2934 | | Arc::new(Literal::new(ScalarValue::Int32(Some(8)))), |
2935 | | )) as Arc<dyn PhysicalExpr>; |
2936 | | |
2937 | | let filter = |
2938 | | JoinFilter::new(filter_expression, column_indices, intermediate_schema); |
2939 | | |
2940 | | let join = |
2941 | | join_with_filter(left, right, on, filter, &JoinType::RightAnti, false)?; |
2942 | | |
2943 | | let columns_header = columns(&join.schema()); |
2944 | | assert_eq!(columns_header, vec!["a2", "b2", "c2"]); |
2945 | | |
2946 | | let stream = join.execute(0, task_ctx)?; |
2947 | | let batches = common::collect(stream).await?; |
2948 | | |
2949 | | let expected = [ |
2950 | | "+----+----+-----+", |
2951 | | "| a2 | b2 | c2 |", |
2952 | | "+----+----+-----+", |
2953 | | "| 8 | 8 | 20 |", |
2954 | | "| 6 | 6 | 60 |", |
2955 | | "| 2 | 2 | 80 |", |
2956 | | "| 4 | 4 | 120 |", |
2957 | | "+----+----+-----+", |
2958 | | ]; |
2959 | | |
2960 | | // RightAnti join output is expected to preserve right input order |
2961 | | assert_batches_eq!(expected, &batches); |
2962 | | |
2963 | | Ok(()) |
2964 | | } |
2965 | | |
2966 | | #[apply(batch_sizes)] |
2967 | | #[tokio::test] |
2968 | | async fn join_right_one(batch_size: usize) -> Result<()> { |
2969 | | let task_ctx = prepare_task_ctx(batch_size); |
2970 | | let left = build_table( |
2971 | | ("a1", &vec![1, 2, 3]), |
2972 | | ("b1", &vec![4, 5, 7]), |
2973 | | ("c1", &vec![7, 8, 9]), |
2974 | | ); |
2975 | | let right = build_table( |
2976 | | ("a2", &vec![10, 20, 30]), |
2977 | | ("b1", &vec![4, 5, 6]), // 6 does not exist on the left |
2978 | | ("c2", &vec![70, 80, 90]), |
2979 | | ); |
2980 | | let on = vec![( |
2981 | | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, |
2982 | | Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, |
2983 | | )]; |
2984 | | |
2985 | | let (columns, batches) = |
2986 | | join_collect(left, right, on, &JoinType::Right, false, task_ctx).await?; |
2987 | | |
2988 | | assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); |
2989 | | |
2990 | | let expected = [ |
2991 | | "+----+----+----+----+----+----+", |
2992 | | "| a1 | b1 | c1 | a2 | b1 | c2 |", |
2993 | | "+----+----+----+----+----+----+", |
2994 | | "| | | | 30 | 6 | 90 |", |
2995 | | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
2996 | | "| 2 | 5 | 8 | 20 | 5 | 80 |", |
2997 | | "+----+----+----+----+----+----+", |
2998 | | ]; |
2999 | | |
3000 | | assert_batches_sorted_eq!(expected, &batches); |
3001 | | |
3002 | | Ok(()) |
3003 | | } |
3004 | | |
3005 | | #[apply(batch_sizes)] |
3006 | | #[tokio::test] |
3007 | | async fn partitioned_join_right_one(batch_size: usize) -> Result<()> { |
3008 | | let task_ctx = prepare_task_ctx(batch_size); |
3009 | | let left = build_table( |
3010 | | ("a1", &vec![1, 2, 3]), |
3011 | | ("b1", &vec![4, 5, 7]), |
3012 | | ("c1", &vec![7, 8, 9]), |
3013 | | ); |
3014 | | let right = build_table( |
3015 | | ("a2", &vec![10, 20, 30]), |
3016 | | ("b1", &vec![4, 5, 6]), // 6 does not exist on the left |
3017 | | ("c2", &vec![70, 80, 90]), |
3018 | | ); |
3019 | | let on = vec![( |
3020 | | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, |
3021 | | Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, |
3022 | | )]; |
3023 | | |
3024 | | let (columns, batches) = |
3025 | | partitioned_join_collect(left, right, on, &JoinType::Right, false, task_ctx) |
3026 | | .await?; |
3027 | | |
3028 | | assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); |
3029 | | |
3030 | | let expected = [ |
3031 | | "+----+----+----+----+----+----+", |
3032 | | "| a1 | b1 | c1 | a2 | b1 | c2 |", |
3033 | | "+----+----+----+----+----+----+", |
3034 | | "| | | | 30 | 6 | 90 |", |
3035 | | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
3036 | | "| 2 | 5 | 8 | 20 | 5 | 80 |", |
3037 | | "+----+----+----+----+----+----+", |
3038 | | ]; |
3039 | | |
3040 | | assert_batches_sorted_eq!(expected, &batches); |
3041 | | |
3042 | | Ok(()) |
3043 | | } |
3044 | | |
3045 | | #[apply(batch_sizes)] |
3046 | | #[tokio::test] |
3047 | | async fn join_full_one(batch_size: usize) -> Result<()> { |
3048 | | let task_ctx = prepare_task_ctx(batch_size); |
3049 | | let left = build_table( |
3050 | | ("a1", &vec![1, 2, 3]), |
3051 | | ("b1", &vec![4, 5, 7]), // 7 does not exist on the right |
3052 | | ("c1", &vec![7, 8, 9]), |
3053 | | ); |
3054 | | let right = build_table( |
3055 | | ("a2", &vec![10, 20, 30]), |
3056 | | ("b2", &vec![4, 5, 6]), |
3057 | | ("c2", &vec![70, 80, 90]), |
3058 | | ); |
3059 | | let on = vec![( |
3060 | | Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _, |
3061 | | Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, |
3062 | | )]; |
3063 | | |
3064 | | let join = join(left, right, on, &JoinType::Full, false)?; |
3065 | | |
3066 | | let columns = columns(&join.schema()); |
3067 | | assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); |
3068 | | |
3069 | | let stream = join.execute(0, task_ctx)?; |
3070 | | let batches = common::collect(stream).await?; |
3071 | | |
3072 | | let expected = [ |
3073 | | "+----+----+----+----+----+----+", |
3074 | | "| a1 | b1 | c1 | a2 | b2 | c2 |", |
3075 | | "+----+----+----+----+----+----+", |
3076 | | "| | | | 30 | 6 | 90 |", |
3077 | | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
3078 | | "| 2 | 5 | 8 | 20 | 5 | 80 |", |
3079 | | "| 3 | 7 | 9 | | | |", |
3080 | | "+----+----+----+----+----+----+", |
3081 | | ]; |
3082 | | assert_batches_sorted_eq!(expected, &batches); |
3083 | | |
3084 | | Ok(()) |
3085 | | } |
3086 | | |
3087 | | #[test] |
3088 | 1 | fn join_with_hash_collision() -> Result<()> { |
3089 | 1 | let mut hashmap_left = RawTable::with_capacity(2); |
3090 | 1 | let left = build_table_i32( |
3091 | 1 | ("a", &vec![10, 20]), |
3092 | 1 | ("x", &vec![100, 200]), |
3093 | 1 | ("y", &vec![200, 300]), |
3094 | 1 | ); |
3095 | 1 | |
3096 | 1 | let random_state = RandomState::with_seeds(0, 0, 0, 0); |
3097 | 1 | let hashes_buff = &mut vec![0; left.num_rows()]; |
3098 | 1 | let hashes = create_hashes( |
3099 | 1 | &[Arc::clone(&left.columns()[0])], |
3100 | 1 | &random_state, |
3101 | 1 | hashes_buff, |
3102 | 1 | )?;
3103 | | |
3104 | | // Create hash collisions (same hashes) |
3105 | 1 | hashmap_left.insert(hashes[0], (hashes[0], 1), |(h, _)| *h);
3106 | 1 | hashmap_left.insert(hashes[1], (hashes[1], 1), |(h, _)| *h);
3107 | 1 | |
3108 | 1 | let next = vec![2, 0]; |
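 | | // Note (per the JoinHashMap docs): the map stores a 1-based row index
 | | // per hash, and `next` chains further rows sharing the same key, with 0
 | | // terminating the chain -- so `next = vec![2, 0]` links row 0 to row 1
 | | // and ends the chain there.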
3109 | 1 | |
3110 | 1 | let right = build_table_i32( |
3111 | 1 | ("a", &vec![10, 20]), |
3112 | 1 | ("b", &vec![0, 0]), |
3113 | 1 | ("c", &vec![30, 40]), |
3114 | 1 | ); |
3115 | 1 | |
3116 | 1 | // Join key column for both join sides |
3117 | 1 | let key_column: PhysicalExprRef = Arc::new(Column::new("a", 0)) as _; |
3118 | 1 | |
3119 | 1 | let join_hash_map = JoinHashMap::new(hashmap_left, next); |
3120 | | |
3121 | 1 | let right_keys_values = |
3122 | 1 | key_column.evaluate(&right)?.into_array(right.num_rows())?;
3123 | 1 | let mut hashes_buffer = vec![0; right.num_rows()]; |
3124 | 1 | create_hashes(&[right_keys_values], &random_state, &mut hashes_buffer)?;
3125 | | |
3126 | 1 | let (l, r, _) = lookup_join_hashmap( |
3127 | 1 | &join_hash_map, |
3128 | 1 | &left, |
3129 | 1 | &right, |
3130 | 1 | &[Arc::clone(&key_column)], |
3131 | 1 | &[key_column], |
3132 | 1 | false, |
3133 | 1 | &hashes_buffer, |
3134 | 1 | 8192, |
3135 | 1 | (0, None), |
3136 | 1 | )?;
3137 | | |
3138 | 1 | let left_ids: UInt64Array = vec![0, 1].into(); |
3139 | 1 | |
3140 | 1 | let right_ids: UInt32Array = vec![0, 1].into(); |
3141 | 1 | |
3142 | 1 | assert_eq!(left_ids, l); |
3143 | | |
3144 | 1 | assert_eq!(right_ids, r); |
3145 | | |
3146 | 1 | Ok(()) |
3147 | 1 | } |
3148 | | |
3149 | | #[tokio::test] |
3150 | 1 | async fn join_with_duplicated_column_names() -> Result<()> { |
3151 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
3152 | 1 | let left = build_table( |
3153 | 1 | ("a", &vec![1, 2, 3]), |
3154 | 1 | ("b", &vec![4, 5, 7]), |
3155 | 1 | ("c", &vec![7, 8, 9]), |
3156 | 1 | ); |
3157 | 1 | let right = build_table( |
3158 | 1 | ("a", &vec![10, 20, 30]), |
3159 | 1 | ("b", &vec![1, 2, 7]), |
3160 | 1 | ("c", &vec![70, 80, 90]), |
3161 | 1 | ); |
3162 | 1 | let on = vec![( |
3163 | 1 | // join on a=b so there are duplicate column names on unjoined columns |
3164 | 1 | Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as _, |
3165 | 1 | Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) as _, |
3166 | 1 | )]; |
3167 | 1 | |
3168 | 1 | let join = join(left, right, on, &JoinType::Inner, false)?;
3169 | 1 | |
3170 | 1 | let columns = columns(&join.schema()); |
3171 | 1 | assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); |
3172 | 1 | |
3173 | 1 | let stream = join.execute(0, task_ctx)?;
3174 | 1 | let batches = common::collect(stream).await?;
3175 | 1 | |
3176 | 1 | let expected = [ |
3177 | 1 | "+---+---+---+----+---+----+", |
3178 | 1 | "| a | b | c | a | b | c |", |
3179 | 1 | "+---+---+---+----+---+----+", |
3180 | 1 | "| 1 | 4 | 7 | 10 | 1 | 70 |", |
3181 | 1 | "| 2 | 5 | 8 | 20 | 2 | 80 |", |
3182 | 1 | "+---+---+---+----+---+----+", |
3183 | 1 | ]; |
3184 | 1 | assert_batches_sorted_eq!(expected, &batches); |
3185 | 1 | |
3186 | 1 | Ok(()) |
3187 | 1 | } |
3188 | | |
3189 | 20 | fn prepare_join_filter() -> JoinFilter { |
3190 | 20 | let column_indices = vec![ |
3191 | 20 | ColumnIndex { |
3192 | 20 | index: 2, |
3193 | 20 | side: JoinSide::Left, |
3194 | 20 | }, |
3195 | 20 | ColumnIndex { |
3196 | 20 | index: 2, |
3197 | 20 | side: JoinSide::Right, |
3198 | 20 | }, |
3199 | 20 | ]; |
3200 | 20 | let intermediate_schema = Schema::new(vec![ |
3201 | 20 | Field::new("c", DataType::Int32, true), |
3202 | 20 | Field::new("c", DataType::Int32, true), |
3203 | 20 | ]); |
3204 | 20 | let filter_expression = Arc::new(BinaryExpr::new( |
3205 | 20 | Arc::new(Column::new("c", 0)), |
3206 | 20 | Operator::Gt, |
3207 | 20 | Arc::new(Column::new("c", 1)), |
3208 | 20 | )) as Arc<dyn PhysicalExpr>; |
3209 | 20 | |
3210 | 20 | JoinFilter::new(filter_expression, column_indices, intermediate_schema) |
3211 | 20 | } |
3212 | | |
3213 | | #[apply(batch_sizes)] |
3214 | | #[tokio::test] |
3215 | | async fn join_inner_with_filter(batch_size: usize) -> Result<()> { |
3216 | | let task_ctx = prepare_task_ctx(batch_size); |
3217 | | let left = build_table( |
3218 | | ("a", &vec![0, 1, 2, 2]), |
3219 | | ("b", &vec![4, 5, 7, 8]), |
3220 | | ("c", &vec![7, 8, 9, 1]), |
3221 | | ); |
3222 | | let right = build_table( |
3223 | | ("a", &vec![10, 20, 30, 40]), |
3224 | | ("b", &vec![2, 2, 3, 4]), |
3225 | | ("c", &vec![7, 5, 6, 4]), |
3226 | | ); |
3227 | | let on = vec![( |
3228 | | Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as _, |
3229 | | Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) as _, |
3230 | | )]; |
3231 | | let filter = prepare_join_filter(); |
3232 | | |
3233 | | let join = join_with_filter(left, right, on, filter, &JoinType::Inner, false)?; |
3234 | | |
3235 | | let columns = columns(&join.schema()); |
3236 | | assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); |
3237 | | |
3238 | | let stream = join.execute(0, task_ctx)?; |
3239 | | let batches = common::collect(stream).await?; |
3240 | | |
3241 | | let expected = [ |
3242 | | "+---+---+---+----+---+---+", |
3243 | | "| a | b | c | a | b | c |", |
3244 | | "+---+---+---+----+---+---+", |
3245 | | "| 2 | 7 | 9 | 10 | 2 | 7 |", |
3246 | | "| 2 | 7 | 9 | 20 | 2 | 5 |", |
3247 | | "+---+---+---+----+---+---+", |
3248 | | ]; |
3249 | | assert_batches_sorted_eq!(expected, &batches); |
3250 | | |
3251 | | Ok(()) |
3252 | | } |
3253 | | |
3254 | | #[apply(batch_sizes)] |
3255 | | #[tokio::test] |
3256 | | async fn join_left_with_filter(batch_size: usize) -> Result<()> { |
3257 | | let task_ctx = prepare_task_ctx(batch_size); |
3258 | | let left = build_table( |
3259 | | ("a", &vec![0, 1, 2, 2]), |
3260 | | ("b", &vec![4, 5, 7, 8]), |
3261 | | ("c", &vec![7, 8, 9, 1]), |
3262 | | ); |
3263 | | let right = build_table( |
3264 | | ("a", &vec![10, 20, 30, 40]), |
3265 | | ("b", &vec![2, 2, 3, 4]), |
3266 | | ("c", &vec![7, 5, 6, 4]), |
3267 | | ); |
3268 | | let on = vec![( |
3269 | | Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as _, |
3270 | | Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) as _, |
3271 | | )]; |
3272 | | let filter = prepare_join_filter(); |
3273 | | |
3274 | | let join = join_with_filter(left, right, on, filter, &JoinType::Left, false)?; |
3275 | | |
3276 | | let columns = columns(&join.schema()); |
3277 | | assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); |
3278 | | |
3279 | | let stream = join.execute(0, task_ctx)?; |
3280 | | let batches = common::collect(stream).await?; |
3281 | | |
3282 | | let expected = [ |
3283 | | "+---+---+---+----+---+---+", |
3284 | | "| a | b | c | a | b | c |", |
3285 | | "+---+---+---+----+---+---+", |
3286 | | "| 0 | 4 | 7 | | | |", |
3287 | | "| 1 | 5 | 8 | | | |", |
3288 | | "| 2 | 7 | 9 | 10 | 2 | 7 |", |
3289 | | "| 2 | 7 | 9 | 20 | 2 | 5 |", |
3290 | | "| 2 | 8 | 1 | | | |", |
3291 | | "+---+---+---+----+---+---+", |
3292 | | ]; |
3293 | | assert_batches_sorted_eq!(expected, &batches); |
3294 | | |
3295 | | Ok(()) |
3296 | | } |
3297 | | |
3298 | | #[apply(batch_sizes)] |
3299 | | #[tokio::test] |
3300 | | async fn join_right_with_filter(batch_size: usize) -> Result<()> { |
3301 | | let task_ctx = prepare_task_ctx(batch_size); |
3302 | | let left = build_table( |
3303 | | ("a", &vec![0, 1, 2, 2]), |
3304 | | ("b", &vec![4, 5, 7, 8]), |
3305 | | ("c", &vec![7, 8, 9, 1]), |
3306 | | ); |
3307 | | let right = build_table( |
3308 | | ("a", &vec![10, 20, 30, 40]), |
3309 | | ("b", &vec![2, 2, 3, 4]), |
3310 | | ("c", &vec![7, 5, 6, 4]), |
3311 | | ); |
3312 | | let on = vec![( |
3313 | | Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as _, |
3314 | | Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) as _, |
3315 | | )]; |
3316 | | let filter = prepare_join_filter(); |
3317 | | |
3318 | | let join = join_with_filter(left, right, on, filter, &JoinType::Right, false)?; |
3319 | | |
3320 | | let columns = columns(&join.schema()); |
3321 | | assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); |
3322 | | |
3323 | | let stream = join.execute(0, task_ctx)?; |
3324 | | let batches = common::collect(stream).await?; |
3325 | | |
3326 | | let expected = [ |
3327 | | "+---+---+---+----+---+---+", |
3328 | | "| a | b | c | a | b | c |", |
3329 | | "+---+---+---+----+---+---+", |
3330 | | "| | | | 30 | 3 | 6 |", |
3331 | | "| | | | 40 | 4 | 4 |", |
3332 | | "| 2 | 7 | 9 | 10 | 2 | 7 |", |
3333 | | "| 2 | 7 | 9 | 20 | 2 | 5 |", |
3334 | | "+---+---+---+----+---+---+", |
3335 | | ]; |
3336 | | assert_batches_sorted_eq!(expected, &batches); |
3337 | | |
3338 | | Ok(()) |
3339 | | } |
3340 | | |
3341 | | #[apply(batch_sizes)] |
3342 | | #[tokio::test] |
3343 | | async fn join_full_with_filter(batch_size: usize) -> Result<()> { |
3344 | | let task_ctx = prepare_task_ctx(batch_size); |
3345 | | let left = build_table( |
3346 | | ("a", &vec![0, 1, 2, 2]), |
3347 | | ("b", &vec![4, 5, 7, 8]), |
3348 | | ("c", &vec![7, 8, 9, 1]), |
3349 | | ); |
3350 | | let right = build_table( |
3351 | | ("a", &vec![10, 20, 30, 40]), |
3352 | | ("b", &vec![2, 2, 3, 4]), |
3353 | | ("c", &vec![7, 5, 6, 4]), |
3354 | | ); |
3355 | | let on = vec![( |
3356 | | Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as _, |
3357 | | Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) as _, |
3358 | | )]; |
3359 | | let filter = prepare_join_filter(); |
3360 | | |
3361 | | let join = join_with_filter(left, right, on, filter, &JoinType::Full, false)?; |
3362 | | |
3363 | | let columns = columns(&join.schema()); |
3364 | | assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); |
3365 | | |
3366 | | let stream = join.execute(0, task_ctx)?; |
3367 | | let batches = common::collect(stream).await?; |
3368 | | |
3369 | | let expected = [ |
3370 | | "+---+---+---+----+---+---+", |
3371 | | "| a | b | c | a | b | c |", |
3372 | | "+---+---+---+----+---+---+", |
3373 | | "| | | | 30 | 3 | 6 |", |
3374 | | "| | | | 40 | 4 | 4 |", |
3375 | | "| 2 | 7 | 9 | 10 | 2 | 7 |", |
3376 | | "| 2 | 7 | 9 | 20 | 2 | 5 |", |
3377 | | "| 0 | 4 | 7 | | | |", |
3378 | | "| 1 | 5 | 8 | | | |", |
3379 | | "| 2 | 8 | 1 | | | |", |
3380 | | "+---+---+---+----+---+---+", |
3381 | | ]; |
3382 | | assert_batches_sorted_eq!(expected, &batches); |
3383 | | |
3384 | | Ok(()) |
3385 | | } |
3386 | | |
3387 | | /// Test for parallelised HashJoinExec with PartitionMode::CollectLeft |
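 | | /// (CollectLeft collects the entire left side into a single hash table
 | | /// that is shared by, and probed from, every output partition.)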
3388 | | #[tokio::test] |
3389 | 1 | async fn test_collect_left_multiple_partitions_join() -> Result<()> { |
3390 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
3391 | 1 | let left = build_table( |
3392 | 1 | ("a1", &vec![1, 2, 3]), |
3393 | 1 | ("b1", &vec![4, 5, 7]), |
3394 | 1 | ("c1", &vec![7, 8, 9]), |
3395 | 1 | ); |
3396 | 1 | let right = build_table( |
3397 | 1 | ("a2", &vec![10, 20, 30]), |
3398 | 1 | ("b2", &vec![4, 5, 6]), |
3399 | 1 | ("c2", &vec![70, 80, 90]), |
3400 | 1 | ); |
3401 | 1 | let on = vec![( |
3402 | 1 | Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _, |
3403 | 1 | Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, |
3404 | 1 | )]; |
3405 | 1 | |
3406 | 1 | let expected_inner = vec![ |
3407 | 1 | "+----+----+----+----+----+----+", |
3408 | 1 | "| a1 | b1 | c1 | a2 | b2 | c2 |", |
3409 | 1 | "+----+----+----+----+----+----+", |
3410 | 1 | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
3411 | 1 | "| 2 | 5 | 8 | 20 | 5 | 80 |", |
3412 | 1 | "+----+----+----+----+----+----+", |
3413 | 1 | ]; |
3414 | 1 | let expected_left = vec![ |
3415 | 1 | "+----+----+----+----+----+----+", |
3416 | 1 | "| a1 | b1 | c1 | a2 | b2 | c2 |", |
3417 | 1 | "+----+----+----+----+----+----+", |
3418 | 1 | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
3419 | 1 | "| 2 | 5 | 8 | 20 | 5 | 80 |", |
3420 | 1 | "| 3 | 7 | 9 | | | |", |
3421 | 1 | "+----+----+----+----+----+----+", |
3422 | 1 | ]; |
3423 | 1 | let expected_right = vec![ |
3424 | 1 | "+----+----+----+----+----+----+", |
3425 | 1 | "| a1 | b1 | c1 | a2 | b2 | c2 |", |
3426 | 1 | "+----+----+----+----+----+----+", |
3427 | 1 | "| | | | 30 | 6 | 90 |", |
3428 | 1 | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
3429 | 1 | "| 2 | 5 | 8 | 20 | 5 | 80 |", |
3430 | 1 | "+----+----+----+----+----+----+", |
3431 | 1 | ]; |
3432 | 1 | let expected_full = vec![ |
3433 | 1 | "+----+----+----+----+----+----+", |
3434 | 1 | "| a1 | b1 | c1 | a2 | b2 | c2 |", |
3435 | 1 | "+----+----+----+----+----+----+", |
3436 | 1 | "| | | | 30 | 6 | 90 |", |
3437 | 1 | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
3438 | 1 | "| 2 | 5 | 8 | 20 | 5 | 80 |", |
3439 | 1 | "| 3 | 7 | 9 | | | |", |
3440 | 1 | "+----+----+----+----+----+----+", |
3441 | 1 | ]; |
3442 | 1 | let expected_left_semi = vec![ |
3443 | 1 | "+----+----+----+", |
3444 | 1 | "| a1 | b1 | c1 |", |
3445 | 1 | "+----+----+----+", |
3446 | 1 | "| 1 | 4 | 7 |", |
3447 | 1 | "| 2 | 5 | 8 |", |
3448 | 1 | "+----+----+----+", |
3449 | 1 | ]; |
3450 | 1 | let expected_left_anti = vec![ |
3451 | 1 | "+----+----+----+", |
3452 | 1 | "| a1 | b1 | c1 |", |
3453 | 1 | "+----+----+----+", |
3454 | 1 | "| 3 | 7 | 9 |", |
3455 | 1 | "+----+----+----+", |
3456 | 1 | ]; |
3457 | 1 | let expected_right_semi = vec![ |
3458 | 1 | "+----+----+----+", |
3459 | 1 | "| a2 | b2 | c2 |", |
3460 | 1 | "+----+----+----+", |
3461 | 1 | "| 10 | 4 | 70 |", |
3462 | 1 | "| 20 | 5 | 80 |", |
3463 | 1 | "+----+----+----+", |
3464 | 1 | ]; |
3465 | 1 | let expected_right_anti = vec![ |
3466 | 1 | "+----+----+----+", |
3467 | 1 | "| a2 | b2 | c2 |", |
3468 | 1 | "+----+----+----+", |
3469 | 1 | "| 30 | 6 | 90 |", |
3470 | 1 | "+----+----+----+", |
3471 | 1 | ]; |
3472 | 1 | |
3473 | 1 | let test_cases = vec![ |
3474 | 1 | (JoinType::Inner, expected_inner), |
3475 | 1 | (JoinType::Left, expected_left), |
3476 | 1 | (JoinType::Right, expected_right), |
3477 | 1 | (JoinType::Full, expected_full), |
3478 | 1 | (JoinType::LeftSemi, expected_left_semi), |
3479 | 1 | (JoinType::LeftAnti, expected_left_anti), |
3480 | 1 | (JoinType::RightSemi, expected_right_semi), |
3481 | 1 | (JoinType::RightAnti, expected_right_anti), |
3482 | 1 | ]; |
3483 | 1 | |
3484 | 9 | for (join_type, expected) in test_cases {
3485 | 8 | let (_, batches) = join_collect_with_partition_mode( |
3486 | 8 | Arc::clone(&left), |
3487 | 8 | Arc::clone(&right), |
3488 | 8 | on.clone(), |
3489 | 8 | &join_type, |
3490 | 8 | PartitionMode::CollectLeft, |
3491 | 8 | false, |
3492 | 8 | Arc::clone(&task_ctx), |
3493 | 8 | ) |
3494 | 24 | .await?;
3495 | 8 | assert_batches_sorted_eq!(expected, &batches); |
3496 | 1 | } |
3497 | 1 | |
3498 | 1 | Ok(()) |
3499 | 1 | } |
3500 | | |
3501 | | #[tokio::test] |
3502 | 1 | async fn join_date32() -> Result<()> { |
3503 | 1 | let schema = Arc::new(Schema::new(vec![ |
3504 | 1 | Field::new("date", DataType::Date32, false), |
3505 | 1 | Field::new("n", DataType::Int32, false), |
3506 | 1 | ])); |
3507 | 1 | |
3508 | 1 | let dates: ArrayRef = Arc::new(Date32Array::from(vec![19107, 19108, 19109])); |
3509 | 1 | let n: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); |
3510 | 1 | let batch = RecordBatch::try_new(Arc::clone(&schema), vec![dates, n])?;
3511 | 1 | let left = Arc::new( |
3512 | 1 | MemoryExec::try_new(&[vec![batch]], Arc::clone(&schema), None).unwrap(), |
3513 | 1 | ); |
3514 | 1 | |
3515 | 1 | let dates: ArrayRef = Arc::new(Date32Array::from(vec![19108, 19108, 19109])); |
3516 | 1 | let n: ArrayRef = Arc::new(Int32Array::from(vec![4, 5, 6])); |
3517 | 1 | let batch = RecordBatch::try_new(Arc::clone(&schema), vec![dates, n])?;
3518 | 1 | let right = Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()); |
3519 | 1 | |
3520 | 1 | let on = vec![( |
3521 | 1 | Arc::new(Column::new_with_schema("date", &left.schema()).unwrap()) as _, |
3522 | 1 | Arc::new(Column::new_with_schema("date", &right.schema()).unwrap()) as _, |
3523 | 1 | )]; |
3524 | 1 | |
3525 | 1 | let join = join(left, right, on, &JoinType::Inner, false)?;
3526 | 1 | |
3527 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
3528 | 1 | let stream = join.execute(0, task_ctx)?;
3529 | 1 | let batches = common::collect(stream).await?;
3530 | 1 | |
3531 | 1 | let expected = [ |
3532 | 1 | "+------------+---+------------+---+", |
3533 | 1 | "| date | n | date | n |", |
3534 | 1 | "+------------+---+------------+---+", |
3535 | 1 | "| 2022-04-26 | 2 | 2022-04-26 | 4 |", |
3536 | 1 | "| 2022-04-26 | 2 | 2022-04-26 | 5 |", |
3537 | 1 | "| 2022-04-27 | 3 | 2022-04-27 | 6 |", |
3538 | 1 | "+------------+---+------------+---+", |
3539 | 1 | ]; |
3540 | 1 | assert_batches_sorted_eq!(expected, &batches); |
3541 | 1 | |
3542 | 1 | Ok(()) |
3543 | 1 | } |
3544 | | |
3545 | | #[tokio::test] |
3546 | 1 | async fn join_with_error_right() { |
3547 | 1 | let left = build_table( |
3548 | 1 | ("a1", &vec![1, 2, 3]), |
3549 | 1 | ("b1", &vec![4, 5, 7]), |
3550 | 1 | ("c1", &vec![7, 8, 9]), |
3551 | 1 | ); |
3552 | 1 | |
3553 | 1 | // right input stream returns one good batch and then one error. |
3554 | 1 | // The error should be returned. |
3555 | 1 | let err = exec_err!("bad data error"); |
3556 | 1 | let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2", &vec![])); |
3557 | 1 | |
3558 | 1 | let on = vec![( |
3559 | 1 | Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _, |
3560 | 1 | Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) as _, |
3561 | 1 | )]; |
3562 | 1 | let schema = right.schema(); |
3563 | 1 | let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2", &vec![])); |
3564 | 1 | let right_input = Arc::new(MockExec::new(vec![Ok(right), err], schema)); |
3565 | 1 | |
3566 | 1 | let join_types = vec![ |
3567 | 1 | JoinType::Inner, |
3568 | 1 | JoinType::Left, |
3569 | 1 | JoinType::Right, |
3570 | 1 | JoinType::Full, |
3571 | 1 | JoinType::LeftSemi, |
3572 | 1 | JoinType::LeftAnti, |
3573 | 1 | JoinType::RightSemi, |
3574 | 1 | JoinType::RightAnti, |
3575 | 1 | ]; |
3576 | 1 | |
3577 | 9 | for join_type in join_types {
3578 | 8 | let join = join( |
3579 | 8 | Arc::clone(&left), |
3580 | 8 | Arc::clone(&right_input) as Arc<dyn ExecutionPlan>, |
3581 | 8 | on.clone(), |
3582 | 8 | &join_type, |
3583 | 8 | false, |
3584 | 8 | ) |
3585 | 8 | .unwrap(); |
3586 | 8 | let task_ctx = Arc::new(TaskContext::default()); |
3587 | 8 | |
3588 | 8 | let stream = join.execute(0, task_ctx).unwrap(); |
3589 | 1 | |
3590 | 1 | // Expect that an error is returned |
3591 | 8 | let result_string = crate::common::collect(stream) |
3592 | 8 | .await |
3593 | 8 | .unwrap_err() |
3594 | 8 | .to_string(); |
3595 | 8 | assert!( |
3596 | 8 | result_string.contains("bad data error"), |
3597 | 1 | "actual: {result_string}"0 |
3598 | 1 | ); |
3599 | 1 | } |
3600 | 1 | } |
3601 | | |
3602 | | #[tokio::test] |
3603 | 1 | async fn join_splitted_batch() { |
3604 | 1 | let left = build_table( |
3605 | 1 | ("a1", &vec![1, 2, 3, 4]), |
3606 | 1 | ("b1", &vec![1, 1, 1, 1]), |
3607 | 1 | ("c1", &vec![0, 0, 0, 0]), |
3608 | 1 | ); |
3609 | 1 | let right = build_table( |
3610 | 1 | ("a2", &vec![10, 20, 30, 40, 50]), |
3611 | 1 | ("b2", &vec![1, 1, 1, 1, 1]), |
3612 | 1 | ("c2", &vec![0, 0, 0, 0, 0]), |
3613 | 1 | ); |
3614 | 1 | let on = vec![( |
3615 | 1 | Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _, |
3616 | 1 | Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, |
3617 | 1 | )]; |
3618 | 1 | |
3619 | 1 | let join_types = vec![ |
3620 | 1 | JoinType::Inner, |
3621 | 1 | JoinType::Left, |
3622 | 1 | JoinType::Right, |
3623 | 1 | JoinType::Full, |
3624 | 1 | JoinType::RightSemi, |
3625 | 1 | JoinType::RightAnti, |
3626 | 1 | JoinType::LeftSemi, |
3627 | 1 | JoinType::LeftAnti, |
3628 | 1 | ]; |
3629 | 1 | let expected_resultset_records = 20; |
3630 | 1 | let common_result = [ |
3631 | 1 | "+----+----+----+----+----+----+", |
3632 | 1 | "| a1 | b1 | c1 | a2 | b2 | c2 |", |
3633 | 1 | "+----+----+----+----+----+----+", |
3634 | 1 | "| 1 | 1 | 0 | 10 | 1 | 0 |", |
3635 | 1 | "| 2 | 1 | 0 | 10 | 1 | 0 |", |
3636 | 1 | "| 3 | 1 | 0 | 10 | 1 | 0 |", |
3637 | 1 | "| 4 | 1 | 0 | 10 | 1 | 0 |", |
3638 | 1 | "| 1 | 1 | 0 | 20 | 1 | 0 |", |
3639 | 1 | "| 2 | 1 | 0 | 20 | 1 | 0 |", |
3640 | 1 | "| 3 | 1 | 0 | 20 | 1 | 0 |", |
3641 | 1 | "| 4 | 1 | 0 | 20 | 1 | 0 |", |
3642 | 1 | "| 1 | 1 | 0 | 30 | 1 | 0 |", |
3643 | 1 | "| 2 | 1 | 0 | 30 | 1 | 0 |", |
3644 | 1 | "| 3 | 1 | 0 | 30 | 1 | 0 |", |
3645 | 1 | "| 4 | 1 | 0 | 30 | 1 | 0 |", |
3646 | 1 | "| 1 | 1 | 0 | 40 | 1 | 0 |", |
3647 | 1 | "| 2 | 1 | 0 | 40 | 1 | 0 |", |
3648 | 1 | "| 3 | 1 | 0 | 40 | 1 | 0 |", |
3649 | 1 | "| 4 | 1 | 0 | 40 | 1 | 0 |", |
3650 | 1 | "| 1 | 1 | 0 | 50 | 1 | 0 |", |
3651 | 1 | "| 2 | 1 | 0 | 50 | 1 | 0 |", |
3652 | 1 | "| 3 | 1 | 0 | 50 | 1 | 0 |", |
3653 | 1 | "| 4 | 1 | 0 | 50 | 1 | 0 |", |
3654 | 1 | "+----+----+----+----+----+----+", |
3655 | 1 | ]; |
3656 | 1 | let left_batch = [ |
3657 | 1 | "+----+----+----+", |
3658 | 1 | "| a1 | b1 | c1 |", |
3659 | 1 | "+----+----+----+", |
3660 | 1 | "| 1 | 1 | 0 |", |
3661 | 1 | "| 2 | 1 | 0 |", |
3662 | 1 | "| 3 | 1 | 0 |", |
3663 | 1 | "| 4 | 1 | 0 |", |
3664 | 1 | "+----+----+----+", |
3665 | 1 | ]; |
3666 | 1 | let right_batch = [ |
3667 | 1 | "+----+----+----+", |
3668 | 1 | "| a2 | b2 | c2 |", |
3669 | 1 | "+----+----+----+", |
3670 | 1 | "| 10 | 1 | 0 |", |
3671 | 1 | "| 20 | 1 | 0 |", |
3672 | 1 | "| 30 | 1 | 0 |", |
3673 | 1 | "| 40 | 1 | 0 |", |
3674 | 1 | "| 50 | 1 | 0 |", |
3675 | 1 | "+----+----+----+", |
3676 | 1 | ]; |
3677 | 1 | let right_empty = [ |
3678 | 1 | "+----+----+----+", |
3679 | 1 | "| a2 | b2 | c2 |", |
3680 | 1 | "+----+----+----+", |
3681 | 1 | "+----+----+----+", |
3682 | 1 | ]; |
3683 | 1 | let left_empty = [ |
3684 | 1 | "+----+----+----+", |
3685 | 1 | "| a1 | b1 | c1 |", |
3686 | 1 | "+----+----+----+", |
3687 | 1 | "+----+----+----+", |
3688 | 1 | ]; |
3689 | 1 | |
3690 | 1 | // Validate partial join result output for different batch_size settings
3691 | 9 | for join_type in join_types {
3692 | 160 | for batch_size in (1..21).rev() {
3693 | 160 | let task_ctx = prepare_task_ctx(batch_size); |
3694 | 160 | |
3695 | 160 | let join = join( |
3696 | 160 | Arc::clone(&left), |
3697 | 160 | Arc::clone(&right), |
3698 | 160 | on.clone(), |
3699 | 160 | &join_type, |
3700 | 160 | false, |
3701 | 160 | ) |
3702 | 160 | .unwrap(); |
3703 | 160 | |
3704 | 160 | let stream = join.execute(0, task_ctx).unwrap(); |
3705 | 160 | let batches = common::collect(stream).await.unwrap();
3706 | 1 | |
3707 | 1 | // For inner/right joins the expected batch count equals the div_ceil result,
3708 | 1 | // as there is no need to append non-joined build side data.
3709 | 1 | // For other join types it'll be div_ceil + 1 -- an additional batch
3710 | 1 | // containing the not-yet-visited build side rows (empty in this test case).
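 | | // E.g. with the 20-row result set and batch_size = 8: div_ceil(20, 8) = 3
 | | // batches for the former group, 3 + 1 = 4 for the latter.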
3711 | 160 | let expected_batch_count = match join_type { |
3712 | 1 | JoinType::Inner |
3713 | 1 | | JoinType::Right |
3714 | 1 | | JoinType::RightSemi |
3715 | 1 | | JoinType::RightAnti => { |
3716 | 80 | div_ceil(expected_resultset_records, batch_size) |
3717 | 1 | } |
3718 | 80 | _ => div_ceil(expected_resultset_records, batch_size) + 1, |
3719 | 1 | }; |
3720 | 160 | assert_eq!( |
3721 | 160 | batches.len(), |
3722 | 1 | expected_batch_count, |
3723 | 1 | "expected {} output batches for {} join with batch_size = {}"0 , |
3724 | 1 | expected_batch_count, |
3725 | 1 | join_type, |
3726 | 1 | batch_size |
3727 | 1 | ); |
3728 | 1 | |
3729 | 160 | let expected = match join_type { |
3730 | 20 | JoinType::RightSemi => right_batch.to_vec(), |
3731 | 20 | JoinType::RightAnti => right_empty.to_vec(), |
3732 | 20 | JoinType::LeftSemi => left_batch.to_vec(), |
3733 | 20 | JoinType::LeftAnti => left_empty.to_vec(), |
3734 | 80 | _ => common_result.to_vec(), |
3735 | 1 | }; |
3736 | 160 | assert_batches_eq!(expected, &batches); |
3737 | 1 | } |
3738 | 1 | } |
3739 | 1 | } |
3740 | | |
3741 | | #[tokio::test] |
3742 | 1 | async fn single_partition_join_overallocation() -> Result<()> { |
3743 | 1 | let left = build_table( |
3744 | 1 | ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), |
3745 | 1 | ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), |
3746 | 1 | ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), |
3747 | 1 | ); |
3748 | 1 | let right = build_table( |
3749 | 1 | ("a2", &vec![10, 11]), |
3750 | 1 | ("b2", &vec![12, 13]), |
3751 | 1 | ("c2", &vec![14, 15]), |
3752 | 1 | ); |
3753 | 1 | let on = vec![( |
3754 | 1 | Arc::new(Column::new_with_schema("a1", &left.schema()).unwrap()) as _, |
3755 | 1 | Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, |
3756 | 1 | )]; |
3757 | 1 | |
3758 | 1 | let join_types = vec![ |
3759 | 1 | JoinType::Inner, |
3760 | 1 | JoinType::Left, |
3761 | 1 | JoinType::Right, |
3762 | 1 | JoinType::Full, |
3763 | 1 | JoinType::LeftSemi, |
3764 | 1 | JoinType::LeftAnti, |
3765 | 1 | JoinType::RightSemi, |
3766 | 1 | JoinType::RightAnti, |
3767 | 1 | ]; |
3768 | 1 | |
3769 | 9 | for join_type in join_types {
3770 | 8 | let runtime = RuntimeEnvBuilder::new() |
3771 | 8 | .with_memory_limit(100, 1.0) |
3772 | 8 | .build_arc()?;
3773 | 8 | let task_ctx = TaskContext::default().with_runtime(runtime); |
3774 | 8 | let task_ctx = Arc::new(task_ctx); |
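 | | // Note: the 100-byte memory limit is deliberately far smaller than the
 | | // build side, so collecting the build-side input is expected to fail.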
3775 | 1 | |
3776 | 8 | let join = join( |
3777 | 8 | Arc::clone(&left), |
3778 | 8 | Arc::clone(&right), |
3779 | 8 | on.clone(), |
3780 | 8 | &join_type, |
3781 | 8 | false, |
3782 | 8 | )?;
3783 | 1 | |
3784 | 8 | let stream = join.execute(0, task_ctx)?;
3785 | 8 | let err = common::collect(stream).await.unwrap_err();
3786 | 8 | |
3787 | 8 | // Assert that the operator-level reservation fails when attempting to overallocate
3788 | 8 | assert_contains!( |
3789 | 8 | err.to_string(), |
3790 | 8 | "External error: Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: HashJoinInput" |
3791 | 8 | ); |
3792 | 1 | } |
3793 | 1 | |
3794 | 1 | Ok(()) |
3795 | 1 | } |
3796 | | |
3797 | | #[tokio::test] |
3798 | 1 | async fn partitioned_join_overallocation() -> Result<()> { |
3799 | 1 | // Prepare partitioned inputs for HashJoinExec |
3800 | 1 | // No need to adjust partitioning, as execution should fail with `Resources exhausted` error |
3801 | 1 | let left_batch = build_table_i32( |
3802 | 1 | ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), |
3803 | 1 | ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), |
3804 | 1 | ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), |
3805 | 1 | ); |
3806 | 1 | let left = Arc::new( |
3807 | 1 | MemoryExec::try_new( |
3808 | 1 | &[vec![left_batch.clone()], vec![left_batch.clone()]], |
3809 | 1 | left_batch.schema(), |
3810 | 1 | None, |
3811 | 1 | ) |
3812 | 1 | .unwrap(), |
3813 | 1 | ); |
3814 | 1 | let right_batch = build_table_i32( |
3815 | 1 | ("a2", &vec![10, 11]), |
3816 | 1 | ("b2", &vec![12, 13]), |
3817 | 1 | ("c2", &vec![14, 15]), |
3818 | 1 | ); |
3819 | 1 | let right = Arc::new( |
3820 | 1 | MemoryExec::try_new( |
3821 | 1 | &[vec![right_batch.clone()], vec![right_batch.clone()]], |
3822 | 1 | right_batch.schema(), |
3823 | 1 | None, |
3824 | 1 | ) |
3825 | 1 | .unwrap(), |
3826 | 1 | ); |
3827 | 1 | let on = vec![( |
3828 | 1 | Arc::new(Column::new_with_schema("b1", &left_batch.schema())?0 ) as _, |
3829 | 1 | Arc::new(Column::new_with_schema("b2", &right_batch.schema())?0 ) as _, |
3830 | 1 | )]; |
3831 | 1 | |
3832 | 1 | let join_types = vec![ |
3833 | 1 | JoinType::Inner, |
3834 | 1 | JoinType::Left, |
3835 | 1 | JoinType::Right, |
3836 | 1 | JoinType::Full, |
3837 | 1 | JoinType::LeftSemi, |
3838 | 1 | JoinType::LeftAnti, |
3839 | 1 | JoinType::RightSemi, |
3840 | 1 | JoinType::RightAnti, |
3841 | 1 | ]; |
3842 | 1 | |
3843 | 9 | for join_type in join_types {
3844 | 8 | let runtime = RuntimeEnvBuilder::new() |
3845 | 8 | .with_memory_limit(100, 1.0) |
3846 | 8 | .build_arc()?;
3847 | 8 | let session_config = SessionConfig::default().with_batch_size(50); |
3848 | 8 | let task_ctx = TaskContext::default() |
3849 | 8 | .with_session_config(session_config) |
3850 | 8 | .with_runtime(runtime); |
3851 | 8 | let task_ctx = Arc::new(task_ctx); |
3852 | 1 | |
3853 | 8 | let join = HashJoinExec::try_new( |
3854 | 8 | Arc::clone(&left) as Arc<dyn ExecutionPlan>, |
3855 | 8 | Arc::clone(&right) as Arc<dyn ExecutionPlan>, |
3856 | 8 | on.clone(), |
3857 | 8 | None, |
3858 | 8 | &join_type, |
3859 | 8 | None, |
3860 | 8 | PartitionMode::Partitioned, |
3861 | 8 | false, |
3862 | 8 | )?;
3863 | 1 | |
3864 | 8 | let stream = join.execute(1, task_ctx)?;
3865 | 8 | let err = common::collect(stream).await.unwrap_err();
3866 | 8 | |
3867 | 8 | // Assert that the stream-level reservation fails when attempting to overallocate
3868 | 8 | assert_contains!( |
3869 | 8 | err.to_string(), |
3870 | 8 | "External error: Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: HashJoinInput[1]" |
3871 | 8 | |
3872 | 8 | ); |
3873 | 1 | } |
3874 | 1 | |
3875 | 1 | Ok(()) |
3876 | 1 | } |
3877 | | |
3878 | 4 | fn build_table_struct( |
3879 | 4 | struct_name: &str, |
3880 | 4 | field_name_and_values: (&str, &Vec<Option<i32>>), |
3881 | 4 | nulls: Option<NullBuffer>, |
3882 | 4 | ) -> Arc<dyn ExecutionPlan> { |
3883 | 4 | let (field_name, values) = field_name_and_values; |
3884 | 4 | let inner_fields = vec![Field::new(field_name, DataType::Int32, true)]; |
3885 | 4 | let schema = Schema::new(vec![Field::new( |
3886 | 4 | struct_name, |
3887 | 4 | DataType::Struct(inner_fields.clone().into()), |
3888 | 4 | nulls.is_some(), |
3889 | 4 | )]); |
3890 | 4 | |
3891 | 4 | let batch = RecordBatch::try_new( |
3892 | 4 | Arc::new(schema), |
3893 | 4 | vec![Arc::new(StructArray::new( |
3894 | 4 | inner_fields.into(), |
3895 | 4 | vec![Arc::new(Int32Array::from(values.clone()))], |
3896 | 4 | nulls, |
3897 | 4 | ))], |
3898 | 4 | ) |
3899 | 4 | .unwrap(); |
3900 | 4 | let schema_ref = batch.schema(); |
3901 | 4 | Arc::new(MemoryExec::try_new(&[vec![batch]], schema_ref, None).unwrap()) |
3902 | 4 | } |
3903 | | |
3904 | | #[tokio::test] |
3905 | 1 | async fn join_on_struct() -> Result<()> { |
3906 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
3907 | 1 | let left = |
3908 | 1 | build_table_struct("n1", ("a", &vec![None, Some(1), Some(2), Some(3)]), None); |
3909 | 1 | let right = |
3910 | 1 | build_table_struct("n2", ("a", &vec![None, Some(1), Some(2), Some(4)]), None); |
3911 | 1 | let on = vec![( |
3912 | 1 | Arc::new(Column::new_with_schema("n1", &left.schema())?0 ) as _, |
3913 | 1 | Arc::new(Column::new_with_schema("n2", &right.schema())?0 ) as _, |
3914 | 1 | )]; |
3915 | 1 | |
3916 | 1 | let (columns, batches) = |
3917 | 1 | join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?;
3918 | 1 | |
3919 | 1 | assert_eq!(columns, vec!["n1", "n2"]); |
3920 | 1 | |
3921 | 1 | let expected = [ |
3922 | 1 | "+--------+--------+", |
3923 | 1 | "| n1 | n2 |", |
3924 | 1 | "+--------+--------+", |
3925 | 1 | "| {a: } | {a: } |", |
3926 | 1 | "| {a: 1} | {a: 1} |", |
3927 | 1 | "| {a: 2} | {a: 2} |", |
3928 | 1 | "+--------+--------+", |
3929 | 1 | ]; |
3930 | 1 | assert_batches_eq!(expected, &batches); |
3931 | 1 | |
3932 | 1 | Ok(()) |
3933 | 1 | } |
3934 | | |
3935 | | #[tokio::test] |
3936 | 1 | async fn join_on_struct_with_nulls() -> Result<()> { |
3937 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
3938 | 1 | let left = |
3939 | 1 | build_table_struct("n1", ("a", &vec![None]), Some(NullBuffer::new_null(1))); |
3940 | 1 | let right = |
3941 | 1 | build_table_struct("n2", ("a", &vec![None]), Some(NullBuffer::new_null(1))); |
3942 | 1 | let on = vec![( |
3943 | 1 | Arc::new(Column::new_with_schema("n1", &left.schema())?0 ) as _, |
3944 | 1 | Arc::new(Column::new_with_schema("n2", &right.schema())?0 ) as _, |
3945 | 1 | )]; |
3946 | 1 | |
3947 | 1 | let (_, batches_null_eq) = join_collect( |
3948 | 1 | Arc::clone(&left), |
3949 | 1 | Arc::clone(&right), |
3950 | 1 | on.clone(), |
3951 | 1 | &JoinType::Inner, |
3952 | 1 | true, |
3953 | 1 | Arc::clone(&task_ctx), |
3954 | 1 | ) |
3955 | 1 | .await?;
3956 | 1 | |
3957 | 1 | let expected_null_eq = [ |
3958 | 1 | "+----+----+", |
3959 | 1 | "| n1 | n2 |", |
3960 | 1 | "+----+----+", |
3961 | 1 | "| | |", |
3962 | 1 | "+----+----+", |
3963 | 1 | ]; |
3964 | 1 | assert_batches_eq!(expected_null_eq, &batches_null_eq); |
3965 | 1 | |
3966 | 1 | let (_, batches_null_neq) = |
3967 | 1 | join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?;
3968 | 1 | |
3969 | 1 | let expected_null_neq = |
3970 | 1 | ["+----+----+", "| n1 | n2 |", "+----+----+", "+----+----+"]; |
3971 | 1 | assert_batches_eq!(expected_null_neq, &batches_null_neq); |
3972 | 1 | |
3973 | 1 | Ok(()) |
3974 | 1 | } |
3975 | | |
3976 | | /// Returns the column names on the schema |
3977 | 160 | fn columns(schema: &Schema) -> Vec<String> { |
3978 | 771 | schema.fields().iter().map(|f| f.name().clone()).collect() |
3979 | 160 | } |
3980 | | } |