/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/joins/nested_loop_join.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Defines the nested loop join plan, which supports every [`JoinType`]. |
19 | | //! The nested loop join can execute in parallel by partitions; its output |
20 | | //! partitioning is determined by the [`JoinType`]. |
21 | | |
22 | | use std::any::Any; |
23 | | use std::fmt::Formatter; |
24 | | use std::sync::atomic::{AtomicUsize, Ordering}; |
25 | | use std::sync::Arc; |
26 | | use std::task::Poll; |
27 | | |
28 | | use super::utils::{asymmetric_join_output_partitioning, need_produce_result_in_final}; |
29 | | use crate::coalesce_partitions::CoalescePartitionsExec; |
30 | | use crate::joins::utils::{ |
31 | | adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, |
32 | | build_join_schema, check_join_is_valid, estimate_join_statistics, |
33 | | get_final_indices_from_bit_map, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, |
34 | | OnceAsync, OnceFut, |
35 | | }; |
36 | | use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; |
37 | | use crate::{ |
38 | | execution_mode_from_children, DisplayAs, DisplayFormatType, Distribution, |
39 | | ExecutionMode, ExecutionPlan, ExecutionPlanProperties, PlanProperties, |
40 | | RecordBatchStream, SendableRecordBatchStream, |
41 | | }; |
42 | | |
43 | | use arrow::array::{BooleanBufferBuilder, UInt32Array, UInt64Array}; |
44 | | use arrow::compute::concat_batches; |
45 | | use arrow::datatypes::{Schema, SchemaRef}; |
46 | | use arrow::record_batch::RecordBatch; |
47 | | use arrow::util::bit_util; |
48 | | use datafusion_common::{exec_datafusion_err, JoinSide, Result, Statistics}; |
49 | | use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; |
50 | | use datafusion_execution::TaskContext; |
51 | | use datafusion_expr::JoinType; |
52 | | use datafusion_physical_expr::equivalence::join_equivalence_properties; |
53 | | |
54 | | use futures::{ready, Stream, StreamExt, TryStreamExt}; |
55 | | use parking_lot::Mutex; |
56 | | |
/// Shared bitmap for visited left-side indices.
///
/// The builder is wrapped in a [`Mutex`] so that it can be reached through a
/// shared reference and locked before being read or updated.
type SharedBitmapBuilder = Mutex<BooleanBufferBuilder>;
/// Left (build-side) data: the collected build batch together with the
/// bookkeeping needed to emit unmatched build-side rows once probing is done.
struct JoinLeftData {
    /// Build-side data collected to single batch
    batch: RecordBatch,
    /// Shared bitmap builder for visited left indices
    bitmap: SharedBitmapBuilder,
    /// Counter of running probe-threads, potentially able to update `bitmap`
    probe_threads_counter: AtomicUsize,
    /// Memory reservation for tracking batch and bitmap
    /// Cleared on `JoinLeftData` drop
    // The field is never read; it is stored only so the reservation is dropped
    // together with the data it accounts for.
    #[allow(dead_code)]
    reservation: MemoryReservation,
}
72 | | |
73 | | impl JoinLeftData { |
74 | 44 | fn new( |
75 | 44 | batch: RecordBatch, |
76 | 44 | bitmap: SharedBitmapBuilder, |
77 | 44 | probe_threads_counter: AtomicUsize, |
78 | 44 | reservation: MemoryReservation, |
79 | 44 | ) -> Self { |
80 | 44 | Self { |
81 | 44 | batch, |
82 | 44 | bitmap, |
83 | 44 | probe_threads_counter, |
84 | 44 | reservation, |
85 | 44 | } |
86 | 44 | } |
87 | | |
88 | 12.1k | fn batch(&self) -> &RecordBatch { |
89 | 12.1k | &self.batch |
90 | 12.1k | } |
91 | | |
92 | 12.2k | fn bitmap(&self) -> &SharedBitmapBuilder { |
93 | 12.2k | &self.bitmap |
94 | 12.2k | } |
95 | | |
96 | | /// Decrements counter of running threads, and returns `true` |
97 | | /// if caller is the last running thread |
98 | 16 | fn report_probe_completed(&self) -> bool { |
99 | 16 | self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1 |
100 | 16 | } |
101 | | } |
102 | | |
/// NestedLoopJoinExec is a build-probe join operator, whose main task is to
/// perform joins without any equijoin conditions in the `ON` clause.
///
/// Execution consists of the following phases:
///
/// #### 1. Build phase
/// Collecting build-side data in memory, by polling all available data from build-side input.
/// Due to the absence of equijoin conditions, it's not possible to partition build-side data
/// across multiple threads of the operator, so build-side is always collected in a single
/// batch shared across all threads.
/// The operator always considers LEFT input as build-side input, so it's crucial to adjust
/// smaller input to be the LEFT one. Normally this selection is handled by physical optimizer.
///
/// #### 2. Probe phase
/// Sequentially polling batches from the probe-side input and processing them according to the
/// following logic:
/// - apply join filter (`ON` clause) to Cartesian product of probe batch and build side data
///   -- filter evaluation is executed once per build-side data row
/// - update shared bitmap of joined ("visited") build-side row indices, if required -- allows
///   to produce unmatched build-side data in case of e.g. LEFT/FULL JOIN after probing phase
///   completed
/// - perform join index alignment if required -- depending on `JoinType`
/// - produce output join batch
///
/// Probing phase is executed in parallel, according to probe-side input partitioning -- one
/// thread per partition. After probe input is exhausted, each thread **ATTEMPTS** to produce
/// unmatched build-side data.
///
/// #### 3. Producing unmatched build-side data
/// Producing unmatched build-side data as an output batch, after probe input is exhausted.
/// This step is also executed in parallel (once per probe input partition), and to avoid
/// duplicate output of unmatched data (due to the shared nature of build-side data), each thread
/// "reports" about probe phase completion (which means that "visited" bitmap won't be
/// updated anymore), and only the last thread, reporting about completion, will return output.
///
#[derive(Debug)]
pub struct NestedLoopJoinExec {
    /// left side
    pub(crate) left: Arc<dyn ExecutionPlan>,
    /// right side
    pub(crate) right: Arc<dyn ExecutionPlan>,
    /// Filters which are applied while finding matching rows
    pub(crate) filter: Option<JoinFilter>,
    /// How the join is performed
    pub(crate) join_type: JoinType,
    /// The schema once the join is applied
    schema: SchemaRef,
    /// Build-side data
    inner_table: OnceAsync<JoinLeftData>,
    /// Information of index and left / right placement of columns
    column_indices: Vec<ColumnIndex>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}
159 | | |
160 | | impl NestedLoopJoinExec { |
161 | | /// Try to create a new [`NestedLoopJoinExec`] |
162 | 52 | pub fn try_new( |
163 | 52 | left: Arc<dyn ExecutionPlan>, |
164 | 52 | right: Arc<dyn ExecutionPlan>, |
165 | 52 | filter: Option<JoinFilter>, |
166 | 52 | join_type: &JoinType, |
167 | 52 | ) -> Result<Self> { |
168 | 52 | let left_schema = left.schema(); |
169 | 52 | let right_schema = right.schema(); |
170 | 52 | check_join_is_valid(&left_schema, &right_schema, &[])?0 ; |
171 | 52 | let (schema, column_indices) = |
172 | 52 | build_join_schema(&left_schema, &right_schema, join_type); |
173 | 52 | let schema = Arc::new(schema); |
174 | 52 | let cache = |
175 | 52 | Self::compute_properties(&left, &right, Arc::clone(&schema), *join_type); |
176 | 52 | |
177 | 52 | Ok(NestedLoopJoinExec { |
178 | 52 | left, |
179 | 52 | right, |
180 | 52 | filter, |
181 | 52 | join_type: *join_type, |
182 | 52 | schema, |
183 | 52 | inner_table: Default::default(), |
184 | 52 | column_indices, |
185 | 52 | metrics: Default::default(), |
186 | 52 | cache, |
187 | 52 | }) |
188 | 52 | } |
189 | | |
190 | | /// left side |
191 | 0 | pub fn left(&self) -> &Arc<dyn ExecutionPlan> { |
192 | 0 | &self.left |
193 | 0 | } |
194 | | |
195 | | /// right side |
196 | 52 | pub fn right(&self) -> &Arc<dyn ExecutionPlan> { |
197 | 52 | &self.right |
198 | 52 | } |
199 | | |
200 | | /// Filters applied before join output |
201 | 0 | pub fn filter(&self) -> Option<&JoinFilter> { |
202 | 0 | self.filter.as_ref() |
203 | 0 | } |
204 | | |
205 | | /// How the join is performed |
206 | 0 | pub fn join_type(&self) -> &JoinType { |
207 | 0 | &self.join_type |
208 | 0 | } |
209 | | |
210 | | /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. |
211 | 52 | fn compute_properties( |
212 | 52 | left: &Arc<dyn ExecutionPlan>, |
213 | 52 | right: &Arc<dyn ExecutionPlan>, |
214 | 52 | schema: SchemaRef, |
215 | 52 | join_type: JoinType, |
216 | 52 | ) -> PlanProperties { |
217 | 52 | // Calculate equivalence properties: |
218 | 52 | let eq_properties = join_equivalence_properties( |
219 | 52 | left.equivalence_properties().clone(), |
220 | 52 | right.equivalence_properties().clone(), |
221 | 52 | &join_type, |
222 | 52 | schema, |
223 | 52 | &Self::maintains_input_order(join_type), |
224 | 52 | None, |
225 | 52 | // No on columns in nested loop join |
226 | 52 | &[], |
227 | 52 | ); |
228 | 52 | |
229 | 52 | let output_partitioning = |
230 | 52 | asymmetric_join_output_partitioning(left, right, &join_type); |
231 | 52 | |
232 | 52 | // Determine execution mode: |
233 | 52 | let mut mode = execution_mode_from_children([left, right]); |
234 | 52 | if mode.is_unbounded() { |
235 | 0 | mode = ExecutionMode::PipelineBreaking; |
236 | 52 | } |
237 | | |
238 | 52 | PlanProperties::new(eq_properties, output_partitioning, mode) |
239 | 52 | } |
240 | | |
241 | | /// Returns a vector indicating whether the left and right inputs maintain their order. |
242 | | /// The first element corresponds to the left input, and the second to the right. |
243 | | /// |
244 | | /// The left (build-side) input's order may change, but the right (probe-side) input's |
245 | | /// order is maintained for INNER, RIGHT, RIGHT ANTI, and RIGHT SEMI joins. |
246 | | /// |
247 | | /// Maintaining the right input's order helps optimize the nodes down the pipeline |
248 | | /// (See [`ExecutionPlan::maintains_input_order`]). |
249 | | /// |
250 | | /// This is a separate method because it is also called when computing properties, before |
251 | | /// a [`NestedLoopJoinExec`] is created. It also takes [`JoinType`] as an argument, as |
252 | | /// opposed to `Self`, for the same reason. |
253 | 164 | fn maintains_input_order(join_type: JoinType) -> Vec<bool> { |
254 | 164 | vec![ |
255 | | false, |
256 | 28 | matches!( |
257 | 164 | join_type, |
258 | | JoinType::Inner |
259 | | | JoinType::Right |
260 | | | JoinType::RightAnti |
261 | | | JoinType::RightSemi |
262 | | ), |
263 | | ] |
264 | 164 | } |
265 | | } |
266 | | |
267 | | impl DisplayAs for NestedLoopJoinExec { |
268 | 0 | fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { |
269 | 0 | match t { |
270 | | DisplayFormatType::Default | DisplayFormatType::Verbose => { |
271 | 0 | let display_filter = self.filter.as_ref().map_or_else( |
272 | 0 | || "".to_string(), |
273 | 0 | |f| format!(", filter={}", f.expression()), |
274 | 0 | ); |
275 | 0 | write!( |
276 | 0 | f, |
277 | 0 | "NestedLoopJoinExec: join_type={:?}{}", |
278 | 0 | self.join_type, display_filter |
279 | 0 | ) |
280 | 0 | } |
281 | 0 | } |
282 | 0 | } |
283 | | } |
284 | | |
285 | | impl ExecutionPlan for NestedLoopJoinExec { |
286 | 0 | fn name(&self) -> &'static str { |
287 | 0 | "NestedLoopJoinExec" |
288 | 0 | } |
289 | | |
290 | 0 | fn as_any(&self) -> &dyn Any { |
291 | 0 | self |
292 | 0 | } |
293 | | |
294 | 52 | fn properties(&self) -> &PlanProperties { |
295 | 52 | &self.cache |
296 | 52 | } |
297 | | |
298 | 0 | fn required_input_distribution(&self) -> Vec<Distribution> { |
299 | 0 | vec![ |
300 | 0 | Distribution::SinglePartition, |
301 | 0 | Distribution::UnspecifiedDistribution, |
302 | 0 | ] |
303 | 0 | } |
304 | | |
305 | 112 | fn maintains_input_order(&self) -> Vec<bool> { |
306 | 112 | Self::maintains_input_order(self.join_type) |
307 | 112 | } |
308 | | |
309 | 0 | fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { |
310 | 0 | vec![&self.left, &self.right] |
311 | 0 | } |
312 | | |
313 | 0 | fn with_new_children( |
314 | 0 | self: Arc<Self>, |
315 | 0 | children: Vec<Arc<dyn ExecutionPlan>>, |
316 | 0 | ) -> Result<Arc<dyn ExecutionPlan>> { |
317 | 0 | Ok(Arc::new(NestedLoopJoinExec::try_new( |
318 | 0 | Arc::clone(&children[0]), |
319 | 0 | Arc::clone(&children[1]), |
320 | 0 | self.filter.clone(), |
321 | 0 | &self.join_type, |
322 | 0 | )?)) |
323 | 0 | } |
324 | | |
325 | 76 | fn execute( |
326 | 76 | &self, |
327 | 76 | partition: usize, |
328 | 76 | context: Arc<TaskContext>, |
329 | 76 | ) -> Result<SendableRecordBatchStream> { |
330 | 76 | let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics); |
331 | 76 | |
332 | 76 | // Initialization reservation for load of inner table |
333 | 76 | let load_reservation = |
334 | 76 | MemoryConsumer::new(format!("NestedLoopJoinLoad[{partition}]")) |
335 | 76 | .register(context.memory_pool()); |
336 | 76 | |
337 | 76 | let inner_table = self.inner_table.once(|| { |
338 | 52 | collect_left_input( |
339 | 52 | Arc::clone(&self.left), |
340 | 52 | Arc::clone(&context), |
341 | 52 | join_metrics.clone(), |
342 | 52 | load_reservation, |
343 | 52 | need_produce_result_in_final(self.join_type), |
344 | 52 | self.right().output_partitioning().partition_count(), |
345 | 52 | ) |
346 | 76 | }); |
347 | | |
348 | 76 | let outer_table = self.right.execute(partition, context)?0 ; |
349 | | |
350 | 76 | let indices_cache = (UInt64Array::new_null(0), UInt32Array::new_null(0)); |
351 | | |
352 | | // Right side has an order and it is maintained during operation. |
353 | 76 | let right_side_ordered = |
354 | 76 | self.maintains_input_order()[1] && self.right.output_ordering().is_some()56 ; |
355 | 76 | Ok(Box::pin(NestedLoopJoinStream { |
356 | 76 | schema: Arc::clone(&self.schema), |
357 | 76 | filter: self.filter.clone(), |
358 | 76 | join_type: self.join_type, |
359 | 76 | outer_table, |
360 | 76 | inner_table, |
361 | 76 | is_exhausted: false, |
362 | 76 | column_indices: self.column_indices.clone(), |
363 | 76 | join_metrics, |
364 | 76 | indices_cache, |
365 | 76 | right_side_ordered, |
366 | 76 | })) |
367 | 76 | } |
368 | | |
369 | 0 | fn metrics(&self) -> Option<MetricsSet> { |
370 | 0 | Some(self.metrics.clone_inner()) |
371 | 0 | } |
372 | | |
373 | 0 | fn statistics(&self) -> Result<Statistics> { |
374 | 0 | estimate_join_statistics( |
375 | 0 | Arc::clone(&self.left), |
376 | 0 | Arc::clone(&self.right), |
377 | 0 | vec![], |
378 | 0 | &self.join_type, |
379 | 0 | &self.schema, |
380 | 0 | ) |
381 | 0 | } |
382 | | } |
383 | | |
384 | | /// Asynchronously collect input into a single batch, and creates `JoinLeftData` from it |
385 | 52 | async fn collect_left_input( |
386 | 52 | input: Arc<dyn ExecutionPlan>, |
387 | 52 | context: Arc<TaskContext>, |
388 | 52 | join_metrics: BuildProbeJoinMetrics, |
389 | 52 | reservation: MemoryReservation, |
390 | 52 | with_visited_left_side: bool, |
391 | 52 | probe_threads_count: usize, |
392 | 52 | ) -> Result<JoinLeftData> { |
393 | 52 | let schema = input.schema(); |
394 | 52 | let merge = if input.output_partitioning().partition_count() != 1 { |
395 | 0 | Arc::new(CoalescePartitionsExec::new(input)) |
396 | | } else { |
397 | 52 | input |
398 | | }; |
399 | 52 | let stream = merge.execute(0, context)?0 ; |
400 | | |
401 | | // Load all batches and count the rows |
402 | 52 | let (batches, metrics, mut reservation44 ) = stream |
403 | 52 | .try_fold( |
404 | 52 | (Vec::new(), join_metrics, reservation), |
405 | 12.1k | |mut acc, batch| async { |
406 | 12.1k | let batch_size = batch.get_array_memory_size(); |
407 | 12.1k | // Reserve memory for incoming batch |
408 | 12.1k | acc.2.try_grow(batch_size)?8 ; |
409 | | // Update metrics |
410 | 12.1k | acc.1.build_mem_used.add(batch_size); |
411 | 12.1k | acc.1.build_input_batches.add(1); |
412 | 12.1k | acc.1.build_input_rows.add(batch.num_rows()); |
413 | 12.1k | // Push batch to output |
414 | 12.1k | acc.0.push(batch); |
415 | 12.1k | Ok(acc) |
416 | 24.2k | }, |
417 | 52 | ) |
418 | 8 | .await0 ?; |
419 | | |
420 | 44 | let merged_batch = concat_batches(&schema, &batches)?0 ; |
421 | | |
422 | | // Reserve memory for visited_left_side bitmap if required by join type |
423 | 44 | let visited_left_side = if with_visited_left_side { |
424 | | // TODO: Replace `ceil` wrapper with stable `div_cell` after |
425 | | // https://github.com/rust-lang/rust/issues/88581 |
426 | 4 | let buffer_size = bit_util::ceil(merged_batch.num_rows(), 8); |
427 | 4 | reservation.try_grow(buffer_size)?0 ; |
428 | 4 | metrics.build_mem_used.add(buffer_size); |
429 | 4 | |
430 | 4 | let mut buffer = BooleanBufferBuilder::new(merged_batch.num_rows()); |
431 | 4 | buffer.append_n(merged_batch.num_rows(), false); |
432 | 4 | buffer |
433 | | } else { |
434 | 40 | BooleanBufferBuilder::new(0) |
435 | | }; |
436 | | |
437 | 44 | Ok(JoinLeftData::new( |
438 | 44 | merged_batch, |
439 | 44 | Mutex::new(visited_left_side), |
440 | 44 | AtomicUsize::new(probe_threads_count), |
441 | 44 | reservation, |
442 | 44 | )) |
443 | 52 | } |
444 | | |
/// A stream that issues [RecordBatch]es as they arrive from the right of the join.
struct NestedLoopJoinStream {
    /// Schema of the batches produced by this stream (the join output schema)
    schema: Arc<Schema>,
    /// join filter
    filter: Option<JoinFilter>,
    /// type of the join
    join_type: JoinType,
    /// the outer table data of the nested loop join
    outer_table: SendableRecordBatchStream,
    /// the inner table data of the nested loop join
    inner_table: OnceFut<JoinLeftData>,
    /// There is nothing to process anymore and left side is processed in case of full join
    is_exhausted: bool,
    /// Information of index and left / right placement of columns
    column_indices: Vec<ColumnIndex>,
    // TODO: support null aware equal
    // null_equals_null: bool
    /// Join execution metrics
    join_metrics: BuildProbeJoinMetrics,
    /// Cache for join indices calculations: the unfiltered (left, right)
    /// index arrays are kept here so they can be reused between probe batches
    indices_cache: (UInt64Array, UInt32Array),
    /// Whether the right side is ordered
    right_side_ordered: bool,
}
470 | | |
471 | | /// Creates a Cartesian product of two input batches, preserving the order of the right batch, |
472 | | /// and applying a join filter if provided. |
473 | | /// |
474 | | /// # Example |
475 | | /// Input: |
476 | | /// left = [0, 1], right = [0, 1, 2] |
477 | | /// |
478 | | /// Output: |
479 | | /// left_indices = [0, 1, 0, 1, 0, 1], right_indices = [0, 0, 1, 1, 2, 2] |
480 | | /// |
481 | | /// Input: |
482 | | /// left = [0, 1, 2], right = [0, 1, 2, 3], filter = left.a != right.a |
483 | | /// |
484 | | /// Output: |
485 | | /// left_indices = [1, 2, 0, 2, 0, 1, 0, 1, 2], right_indices = [0, 0, 1, 1, 2, 2, 3, 3, 3] |
486 | 12.1k | fn build_join_indices( |
487 | 12.1k | left_batch: &RecordBatch, |
488 | 12.1k | right_batch: &RecordBatch, |
489 | 12.1k | filter: Option<&JoinFilter>, |
490 | 12.1k | indices_cache: &mut (UInt64Array, UInt32Array), |
491 | 12.1k | ) -> Result<(UInt64Array, UInt32Array)> { |
492 | 12.1k | let left_row_count = left_batch.num_rows(); |
493 | 12.1k | let right_row_count = right_batch.num_rows(); |
494 | 12.1k | let output_row_count = left_row_count * right_row_count; |
495 | 12.1k | |
496 | 12.1k | // We always use the same indices before applying the filter, so we can cache them |
497 | 12.1k | let (left_indices_cache, right_indices_cache) = indices_cache; |
498 | 12.1k | let cached_output_row_count = left_indices_cache.len(); |
499 | | |
500 | 12.1k | let (left_indices, right_indices) = |
501 | 12.1k | match output_row_count.cmp(&cached_output_row_count) { |
502 | | std::cmp::Ordering::Equal => { |
503 | | // Reuse the cached indices |
504 | 12.0k | (left_indices_cache.clone(), right_indices_cache.clone()) |
505 | | } |
506 | | std::cmp::Ordering::Less => { |
507 | | // Left_row_count never changes because it's the build side. The changes to the |
508 | | // right_row_count can be handled trivially by taking the first output_row_count |
509 | | // elements of the cache because of how the indices are generated. |
510 | | // (See the Ordering::Greater match arm) |
511 | 0 | ( |
512 | 0 | left_indices_cache.slice(0, output_row_count), |
513 | 0 | right_indices_cache.slice(0, output_row_count), |
514 | 0 | ) |
515 | | } |
516 | | std::cmp::Ordering::Greater => { |
517 | | // Rebuild the indices cache |
518 | | |
519 | | // Produces 0, 1, 2, 0, 1, 2, 0, 1, 2, ... |
520 | 44 | *left_indices_cache = UInt64Array::from_iter_values( |
521 | 13.2M | (0..output_row_count as u64).map(|i| i % left_row_count as u64), |
522 | 44 | ); |
523 | 44 | |
524 | 44 | // Produces 0, 0, 0, 1, 1, 1, 2, 2, 2, ... |
525 | 44 | *right_indices_cache = UInt32Array::from_iter_values( |
526 | 13.2M | (0..output_row_count as u32).map(|i| i / left_row_count as u32), |
527 | 44 | ); |
528 | 44 | |
529 | 44 | (left_indices_cache.clone(), right_indices_cache.clone()) |
530 | | } |
531 | | }; |
532 | | |
533 | 12.1k | if let Some(filter) = filter { |
534 | 12.1k | apply_join_filter_to_indices( |
535 | 12.1k | left_batch, |
536 | 12.1k | right_batch, |
537 | 12.1k | left_indices, |
538 | 12.1k | right_indices, |
539 | 12.1k | filter, |
540 | 12.1k | JoinSide::Left, |
541 | 12.1k | ) |
542 | | } else { |
543 | 0 | Ok((left_indices, right_indices)) |
544 | | } |
545 | 12.1k | } |
546 | | |
impl NestedLoopJoinStream {
    /// Drives the join for one poll.
    ///
    /// First waits for the shared build-side data, then polls the probe side:
    /// each probe batch is joined against the build side and returned. When
    /// the probe side ends and the join type needs a final result, the last
    /// stream to report completion emits one more batch built from the shared
    /// bitmap; every other stream ends without further output.
    fn poll_next_impl(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        // all left row
        let build_timer = self.join_metrics.build_time.timer();
        // `ready!` returns `Poll::Pending` from this function until the
        // build side is available.
        let left_data = match ready!(self.inner_table.get_shared(cx)) {
            Ok(data) => data,
            Err(e) => return Poll::Ready(Some(Err(e))),
        };
        build_timer.done();

        // Borrow the shared visited_left_side bitmap stored with the build-side data
        let visited_left_side = left_data.bitmap();

        // Check is_exhausted before polling the outer_table, such that when the outer table
        // does not support `FusedStream`, Self will not poll it again
        if self.is_exhausted {
            return Poll::Ready(None);
        }

        self.outer_table
            .poll_next_unpin(cx)
            .map(|maybe_batch| match maybe_batch {
                Some(Ok(right_batch)) => {
                    // Setting up timer & updating input metrics
                    self.join_metrics.input_batches.add(1);
                    self.join_metrics.input_rows.add(right_batch.num_rows());
                    let timer = self.join_metrics.join_time.timer();

                    let result = join_left_and_right_batch(
                        left_data.batch(),
                        &right_batch,
                        self.join_type,
                        self.filter.as_ref(),
                        &self.column_indices,
                        &self.schema,
                        visited_left_side,
                        &mut self.indices_cache,
                        self.right_side_ordered,
                    );

                    // Recording time & updating output metrics
                    if let Ok(batch) = &result {
                        timer.done();
                        self.join_metrics.output_batches.add(1);
                        self.join_metrics.output_rows.add(batch.num_rows());
                    }

                    Some(result)
                }
                // The only remaining `Some` case is an error from the probe
                // side; it is forwarded unchanged.
                Some(err) => Some(err),
                None => {
                    if need_produce_result_in_final(self.join_type) {
                        // At this stage `visited_left_side` won't be updated, so it's
                        // safe to report about probe completion.
                        //
                        // Setting `is_exhausted` / returning None will prevent from
                        // multiple calls of `report_probe_completed()`
                        if !left_data.report_probe_completed() {
                            self.is_exhausted = true;
                            // This `return` leaves the closure, not the function:
                            // the stream yields `Poll::Ready(None)`.
                            return None;
                        };

                        // Only setting up timer, input is exhausted
                        let timer = self.join_metrics.join_time.timer();
                        // use the global left bitmap to produce the left indices and right indices
                        let (left_side, right_side) =
                            get_final_indices_from_shared_bitmap(
                                visited_left_side,
                                self.join_type,
                            );
                        let empty_right_batch =
                            RecordBatch::new_empty(self.outer_table.schema());
                        // use the left and right indices to produce the batch result
                        let result = build_batch_from_indices(
                            &self.schema,
                            left_data.batch(),
                            &empty_right_batch,
                            &left_side,
                            &right_side,
                            &self.column_indices,
                            JoinSide::Left,
                        );
                        self.is_exhausted = true;

                        // Recording time & updating output metrics
                        if let Ok(batch) = &result {
                            timer.done();
                            self.join_metrics.output_batches.add(1);
                            self.join_metrics.output_rows.add(batch.num_rows());
                        }

                        Some(result)
                    } else {
                        // end of the join loop
                        // NOTE(review): `is_exhausted` is not set on this path, so a
                        // poll after end-of-stream would poll `outer_table` again --
                        // confirm this is acceptable for non-fused probe streams.
                        None
                    }
                }
            })
    }
}
650 | | |
651 | | #[allow(clippy::too_many_arguments)] |
652 | 12.1k | fn join_left_and_right_batch( |
653 | 12.1k | left_batch: &RecordBatch, |
654 | 12.1k | right_batch: &RecordBatch, |
655 | 12.1k | join_type: JoinType, |
656 | 12.1k | filter: Option<&JoinFilter>, |
657 | 12.1k | column_indices: &[ColumnIndex], |
658 | 12.1k | schema: &Schema, |
659 | 12.1k | visited_left_side: &SharedBitmapBuilder, |
660 | 12.1k | indices_cache: &mut (UInt64Array, UInt32Array), |
661 | 12.1k | right_side_ordered: bool, |
662 | 12.1k | ) -> Result<RecordBatch> { |
663 | 12.1k | let (left_side, right_side) = |
664 | 12.1k | build_join_indices(left_batch, right_batch, filter, indices_cache).map_err( |
665 | 12.1k | |e| { |
666 | 0 | exec_datafusion_err!( |
667 | 0 | "Fail to build join indices in NestedLoopJoinExec, error: {e}" |
668 | 0 | ) |
669 | 12.1k | }, |
670 | 12.1k | )?0 ; |
671 | | |
672 | | // set the left bitmap |
673 | | // and only full join need the left bitmap |
674 | 12.1k | if need_produce_result_in_final(join_type) { |
675 | 4 | let mut bitmap = visited_left_side.lock(); |
676 | 4 | left_side.values().iter().for_each(|x| { |
677 | 4 | bitmap.set_bit(*x as usize, true); |
678 | 4 | }); |
679 | 12.1k | } |
680 | | // adjust the two side indices base on the join type |
681 | 12.1k | let (left_side, right_side) = adjust_indices_by_join_type( |
682 | 12.1k | left_side, |
683 | 12.1k | right_side, |
684 | 12.1k | 0..right_batch.num_rows(), |
685 | 12.1k | join_type, |
686 | 12.1k | right_side_ordered, |
687 | 12.1k | ); |
688 | 12.1k | |
689 | 12.1k | build_batch_from_indices( |
690 | 12.1k | schema, |
691 | 12.1k | left_batch, |
692 | 12.1k | right_batch, |
693 | 12.1k | &left_side, |
694 | 12.1k | &right_side, |
695 | 12.1k | column_indices, |
696 | 12.1k | JoinSide::Left, |
697 | 12.1k | ) |
698 | 12.1k | } |
699 | | |
700 | 4 | fn get_final_indices_from_shared_bitmap( |
701 | 4 | shared_bitmap: &SharedBitmapBuilder, |
702 | 4 | join_type: JoinType, |
703 | 4 | ) -> (UInt64Array, UInt32Array) { |
704 | 4 | let bitmap = shared_bitmap.lock(); |
705 | 4 | get_final_indices_from_bit_map(&bitmap, join_type) |
706 | 4 | } |
707 | | |
708 | | impl Stream for NestedLoopJoinStream { |
709 | | type Item = Result<RecordBatch>; |
710 | | |
711 | 12.2k | fn poll_next( |
712 | 12.2k | mut self: std::pin::Pin<&mut Self>, |
713 | 12.2k | cx: &mut std::task::Context<'_>, |
714 | 12.2k | ) -> Poll<Option<Self::Item>> { |
715 | 12.2k | self.poll_next_impl(cx) |
716 | 12.2k | } |
717 | | } |
718 | | |
impl RecordBatchStream for NestedLoopJoinStream {
    /// Returns the schema of the batches this stream produces (a cheap
    /// reference-count clone of the stored `Arc`).
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
724 | | |
725 | | #[cfg(test)] |
726 | | mod tests { |
727 | | use super::*; |
728 | | use crate::{ |
729 | | common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec, |
730 | | test::build_table_i32, |
731 | | }; |
732 | | |
733 | | use arrow::datatypes::{DataType, Field}; |
734 | | use arrow_array::Int32Array; |
735 | | use arrow_schema::SortOptions; |
736 | | use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue}; |
737 | | use datafusion_execution::runtime_env::RuntimeEnvBuilder; |
738 | | use datafusion_expr::Operator; |
739 | | use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; |
740 | | use datafusion_physical_expr::{Partitioning, PhysicalExpr}; |
741 | | use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; |
742 | | |
743 | | use rstest::rstest; |
744 | | |
745 | 90 | fn build_table( |
746 | 90 | a: (&str, &Vec<i32>), |
747 | 90 | b: (&str, &Vec<i32>), |
748 | 90 | c: (&str, &Vec<i32>), |
749 | 90 | batch_size: Option<usize>, |
750 | 90 | sorted_column_names: Vec<&str>, |
751 | 90 | ) -> Arc<dyn ExecutionPlan> { |
752 | 90 | let batch = build_table_i32(a, b, c); |
753 | 90 | let schema = batch.schema(); |
754 | | |
755 | 90 | let batches = if let Some(batch_size72 ) = batch_size { |
756 | 72 | let num_batches = batch.num_rows().div_ceil(batch_size); |
757 | 72 | (0..num_batches) |
758 | 24.2k | .map(|i| { |
759 | 24.2k | let start = i * batch_size; |
760 | 24.2k | let remaining_rows = batch.num_rows() - start; |
761 | 24.2k | batch.slice(start, batch_size.min(remaining_rows)) |
762 | 24.2k | }) |
763 | 72 | .collect::<Vec<_>>() |
764 | | } else { |
765 | 18 | vec![batch] |
766 | | }; |
767 | | |
768 | 90 | let mut exec = |
769 | 90 | MemoryExec::try_new(&[batches], Arc::clone(&schema), None).unwrap(); |
770 | 90 | if !sorted_column_names.is_empty() { |
771 | 36 | let mut sort_info = Vec::new(); |
772 | 144 | for name108 in sorted_column_names { |
773 | 108 | let index = schema.index_of(name).unwrap(); |
774 | 108 | let sort_expr = PhysicalSortExpr { |
775 | 108 | expr: Arc::new(Column::new(name, index)), |
776 | 108 | options: SortOptions { |
777 | 108 | descending: false, |
778 | 108 | nulls_first: false, |
779 | 108 | }, |
780 | 108 | }; |
781 | 108 | sort_info.push(sort_expr); |
782 | 108 | } |
783 | 36 | exec = exec.with_sort_information(vec![sort_info]); |
784 | 54 | } |
785 | | |
786 | 90 | Arc::new(exec) |
787 | 90 | } |
788 | | |
789 | 8 | fn build_left_table() -> Arc<dyn ExecutionPlan> { |
790 | 8 | build_table( |
791 | 8 | ("a1", &vec![5, 9, 11]), |
792 | 8 | ("b1", &vec![5, 8, 8]), |
793 | 8 | ("c1", &vec![50, 90, 110]), |
794 | 8 | None, |
795 | 8 | Vec::new(), |
796 | 8 | ) |
797 | 8 | } |
798 | | |
799 | 8 | fn build_right_table() -> Arc<dyn ExecutionPlan> { |
800 | 8 | build_table( |
801 | 8 | ("a2", &vec![12, 2, 10]), |
802 | 8 | ("b2", &vec![10, 2, 10]), |
803 | 8 | ("c2", &vec![40, 80, 100]), |
804 | 8 | None, |
805 | 8 | Vec::new(), |
806 | 8 | ) |
807 | 8 | } |
808 | | |
809 | 9 | fn prepare_join_filter() -> JoinFilter { |
810 | 9 | let column_indices = vec![ |
811 | 9 | ColumnIndex { |
812 | 9 | index: 1, |
813 | 9 | side: JoinSide::Left, |
814 | 9 | }, |
815 | 9 | ColumnIndex { |
816 | 9 | index: 1, |
817 | 9 | side: JoinSide::Right, |
818 | 9 | }, |
819 | 9 | ]; |
820 | 9 | let intermediate_schema = Schema::new(vec![ |
821 | 9 | Field::new("x", DataType::Int32, true), |
822 | 9 | Field::new("x", DataType::Int32, true), |
823 | 9 | ]); |
824 | 9 | // left.b1!=8 |
825 | 9 | let left_filter = Arc::new(BinaryExpr::new( |
826 | 9 | Arc::new(Column::new("x", 0)), |
827 | 9 | Operator::NotEq, |
828 | 9 | Arc::new(Literal::new(ScalarValue::Int32(Some(8)))), |
829 | 9 | )) as Arc<dyn PhysicalExpr>; |
830 | 9 | // right.b2!=10 |
831 | 9 | let right_filter = Arc::new(BinaryExpr::new( |
832 | 9 | Arc::new(Column::new("x", 1)), |
833 | 9 | Operator::NotEq, |
834 | 9 | Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), |
835 | 9 | )) as Arc<dyn PhysicalExpr>; |
836 | 9 | // filter = left.b1!=8 and right.b2!=10 |
837 | 9 | // after filter: |
838 | 9 | // left table: |
839 | 9 | // ("a1", &vec![5]), |
840 | 9 | // ("b1", &vec![5]), |
841 | 9 | // ("c1", &vec![50]), |
842 | 9 | // right table: |
843 | 9 | // ("a2", &vec![2]), |
844 | 9 | // ("b2", &vec![2]), |
845 | 9 | // ("c2", &vec![80]), |
846 | 9 | let filter_expression = |
847 | 9 | Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter)) |
848 | 9 | as Arc<dyn PhysicalExpr>; |
849 | 9 | |
850 | 9 | JoinFilter::new(filter_expression, column_indices, intermediate_schema) |
851 | 9 | } |
852 | | |
853 | 16 | async fn multi_partitioned_join_collect( |
854 | 16 | left: Arc<dyn ExecutionPlan>, |
855 | 16 | right: Arc<dyn ExecutionPlan>, |
856 | 16 | join_type: &JoinType, |
857 | 16 | join_filter: Option<JoinFilter>, |
858 | 16 | context: Arc<TaskContext>, |
859 | 16 | ) -> Result<(Vec<String>, Vec<RecordBatch>)> { |
860 | 16 | let partition_count = 4; |
861 | | |
862 | | // Redistributing right input |
863 | 16 | let right = Arc::new(RepartitionExec::try_new( |
864 | 16 | right, |
865 | 16 | Partitioning::RoundRobinBatch(partition_count), |
866 | 16 | )?0 ) as Arc<dyn ExecutionPlan>; |
867 | | |
868 | | // Use the required distribution for nested loop join to test partition data |
869 | 16 | let nested_loop_join = |
870 | 16 | NestedLoopJoinExec::try_new(left, right, join_filter, join_type)?0 ; |
871 | 16 | let columns = columns(&nested_loop_join.schema()); |
872 | 16 | let mut batches = vec![]; |
873 | 40 | for i in 0..partition_count16 { |
874 | 40 | let stream = nested_loop_join.execute(i, Arc::clone(&context))?0 ; |
875 | 40 | let more_batches32 = common::collect(stream).await8 ?8 ; |
876 | 32 | batches.extend( |
877 | 32 | more_batches |
878 | 32 | .into_iter() |
879 | 32 | .filter(|b| b.num_rows() > 012 ) |
880 | 32 | .collect::<Vec<_>>(), |
881 | 32 | ); |
882 | 32 | } |
883 | 8 | Ok((columns, batches)) |
884 | 16 | } |
885 | | |
886 | | #[tokio::test] |
887 | 1 | async fn join_inner_with_filter() -> Result<()> { |
888 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
889 | 1 | let left = build_left_table(); |
890 | 1 | let right = build_right_table(); |
891 | 1 | let filter = prepare_join_filter(); |
892 | 1 | let (columns, batches) = multi_partitioned_join_collect( |
893 | 1 | left, |
894 | 1 | right, |
895 | 1 | &JoinType::Inner, |
896 | 1 | Some(filter), |
897 | 1 | task_ctx, |
898 | 1 | ) |
899 | 1 | .await?0 ; |
900 | 1 | assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); |
901 | 1 | let expected = [ |
902 | 1 | "+----+----+----+----+----+----+", |
903 | 1 | "| a1 | b1 | c1 | a2 | b2 | c2 |", |
904 | 1 | "+----+----+----+----+----+----+", |
905 | 1 | "| 5 | 5 | 50 | 2 | 2 | 80 |", |
906 | 1 | "+----+----+----+----+----+----+", |
907 | 1 | ]; |
908 | 1 | |
909 | 1 | assert_batches_sorted_eq!(expected, &batches); |
910 | 1 | |
911 | 1 | Ok(()) |
912 | 1 | } |
913 | | |
914 | | #[tokio::test] |
915 | 1 | async fn join_left_with_filter() -> Result<()> { |
916 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
917 | 1 | let left = build_left_table(); |
918 | 1 | let right = build_right_table(); |
919 | 1 | |
920 | 1 | let filter = prepare_join_filter(); |
921 | 1 | let (columns, batches) = multi_partitioned_join_collect( |
922 | 1 | left, |
923 | 1 | right, |
924 | 1 | &JoinType::Left, |
925 | 1 | Some(filter), |
926 | 1 | task_ctx, |
927 | 1 | ) |
928 | 1 | .await?0 ; |
929 | 1 | assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); |
930 | 1 | let expected = [ |
931 | 1 | "+----+----+-----+----+----+----+", |
932 | 1 | "| a1 | b1 | c1 | a2 | b2 | c2 |", |
933 | 1 | "+----+----+-----+----+----+----+", |
934 | 1 | "| 11 | 8 | 110 | | | |", |
935 | 1 | "| 5 | 5 | 50 | 2 | 2 | 80 |", |
936 | 1 | "| 9 | 8 | 90 | | | |", |
937 | 1 | "+----+----+-----+----+----+----+", |
938 | 1 | ]; |
939 | 1 | |
940 | 1 | assert_batches_sorted_eq!(expected, &batches); |
941 | 1 | |
942 | 1 | Ok(()) |
943 | 1 | } |
944 | | |
945 | | #[tokio::test] |
946 | 1 | async fn join_right_with_filter() -> Result<()> { |
947 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
948 | 1 | let left = build_left_table(); |
949 | 1 | let right = build_right_table(); |
950 | 1 | |
951 | 1 | let filter = prepare_join_filter(); |
952 | 1 | let (columns, batches) = multi_partitioned_join_collect( |
953 | 1 | left, |
954 | 1 | right, |
955 | 1 | &JoinType::Right, |
956 | 1 | Some(filter), |
957 | 1 | task_ctx, |
958 | 1 | ) |
959 | 1 | .await?0 ; |
960 | 1 | assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); |
961 | 1 | let expected = [ |
962 | 1 | "+----+----+----+----+----+-----+", |
963 | 1 | "| a1 | b1 | c1 | a2 | b2 | c2 |", |
964 | 1 | "+----+----+----+----+----+-----+", |
965 | 1 | "| | | | 10 | 10 | 100 |", |
966 | 1 | "| | | | 12 | 10 | 40 |", |
967 | 1 | "| 5 | 5 | 50 | 2 | 2 | 80 |", |
968 | 1 | "+----+----+----+----+----+-----+", |
969 | 1 | ]; |
970 | 1 | |
971 | 1 | assert_batches_sorted_eq!(expected, &batches); |
972 | 1 | |
973 | 1 | Ok(()) |
974 | 1 | } |
975 | | |
976 | | #[tokio::test] |
977 | 1 | async fn join_full_with_filter() -> Result<()> { |
978 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
979 | 1 | let left = build_left_table(); |
980 | 1 | let right = build_right_table(); |
981 | 1 | |
982 | 1 | let filter = prepare_join_filter(); |
983 | 1 | let (columns, batches) = multi_partitioned_join_collect( |
984 | 1 | left, |
985 | 1 | right, |
986 | 1 | &JoinType::Full, |
987 | 1 | Some(filter), |
988 | 1 | task_ctx, |
989 | 1 | ) |
990 | 1 | .await?0 ; |
991 | 1 | assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); |
992 | 1 | let expected = [ |
993 | 1 | "+----+----+-----+----+----+-----+", |
994 | 1 | "| a1 | b1 | c1 | a2 | b2 | c2 |", |
995 | 1 | "+----+----+-----+----+----+-----+", |
996 | 1 | "| | | | 10 | 10 | 100 |", |
997 | 1 | "| | | | 12 | 10 | 40 |", |
998 | 1 | "| 11 | 8 | 110 | | | |", |
999 | 1 | "| 5 | 5 | 50 | 2 | 2 | 80 |", |
1000 | 1 | "| 9 | 8 | 90 | | | |", |
1001 | 1 | "+----+----+-----+----+----+-----+", |
1002 | 1 | ]; |
1003 | 1 | |
1004 | 1 | assert_batches_sorted_eq!(expected, &batches); |
1005 | 1 | |
1006 | 1 | Ok(()) |
1007 | 1 | } |
1008 | | |
1009 | | #[tokio::test] |
1010 | 1 | async fn join_left_semi_with_filter() -> Result<()> { |
1011 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
1012 | 1 | let left = build_left_table(); |
1013 | 1 | let right = build_right_table(); |
1014 | 1 | |
1015 | 1 | let filter = prepare_join_filter(); |
1016 | 1 | let (columns, batches) = multi_partitioned_join_collect( |
1017 | 1 | left, |
1018 | 1 | right, |
1019 | 1 | &JoinType::LeftSemi, |
1020 | 1 | Some(filter), |
1021 | 1 | task_ctx, |
1022 | 1 | ) |
1023 | 1 | .await?0 ; |
1024 | 1 | assert_eq!(columns, vec!["a1", "b1", "c1"]); |
1025 | 1 | let expected = [ |
1026 | 1 | "+----+----+----+", |
1027 | 1 | "| a1 | b1 | c1 |", |
1028 | 1 | "+----+----+----+", |
1029 | 1 | "| 5 | 5 | 50 |", |
1030 | 1 | "+----+----+----+", |
1031 | 1 | ]; |
1032 | 1 | |
1033 | 1 | assert_batches_sorted_eq!(expected, &batches); |
1034 | 1 | |
1035 | 1 | Ok(()) |
1036 | 1 | } |
1037 | | |
1038 | | #[tokio::test] |
1039 | 1 | async fn join_left_anti_with_filter() -> Result<()> { |
1040 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
1041 | 1 | let left = build_left_table(); |
1042 | 1 | let right = build_right_table(); |
1043 | 1 | |
1044 | 1 | let filter = prepare_join_filter(); |
1045 | 1 | let (columns, batches) = multi_partitioned_join_collect( |
1046 | 1 | left, |
1047 | 1 | right, |
1048 | 1 | &JoinType::LeftAnti, |
1049 | 1 | Some(filter), |
1050 | 1 | task_ctx, |
1051 | 1 | ) |
1052 | 1 | .await?0 ; |
1053 | 1 | assert_eq!(columns, vec!["a1", "b1", "c1"]); |
1054 | 1 | let expected = [ |
1055 | 1 | "+----+----+-----+", |
1056 | 1 | "| a1 | b1 | c1 |", |
1057 | 1 | "+----+----+-----+", |
1058 | 1 | "| 11 | 8 | 110 |", |
1059 | 1 | "| 9 | 8 | 90 |", |
1060 | 1 | "+----+----+-----+", |
1061 | 1 | ]; |
1062 | 1 | |
1063 | 1 | assert_batches_sorted_eq!(expected, &batches); |
1064 | 1 | |
1065 | 1 | Ok(()) |
1066 | 1 | } |
1067 | | |
1068 | | #[tokio::test] |
1069 | 1 | async fn join_right_semi_with_filter() -> Result<()> { |
1070 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
1071 | 1 | let left = build_left_table(); |
1072 | 1 | let right = build_right_table(); |
1073 | 1 | |
1074 | 1 | let filter = prepare_join_filter(); |
1075 | 1 | let (columns, batches) = multi_partitioned_join_collect( |
1076 | 1 | left, |
1077 | 1 | right, |
1078 | 1 | &JoinType::RightSemi, |
1079 | 1 | Some(filter), |
1080 | 1 | task_ctx, |
1081 | 1 | ) |
1082 | 1 | .await?0 ; |
1083 | 1 | assert_eq!(columns, vec!["a2", "b2", "c2"]); |
1084 | 1 | let expected = [ |
1085 | 1 | "+----+----+----+", |
1086 | 1 | "| a2 | b2 | c2 |", |
1087 | 1 | "+----+----+----+", |
1088 | 1 | "| 2 | 2 | 80 |", |
1089 | 1 | "+----+----+----+", |
1090 | 1 | ]; |
1091 | 1 | |
1092 | 1 | assert_batches_sorted_eq!(expected, &batches); |
1093 | 1 | |
1094 | 1 | Ok(()) |
1095 | 1 | } |
1096 | | |
1097 | | #[tokio::test] |
1098 | 1 | async fn join_right_anti_with_filter() -> Result<()> { |
1099 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
1100 | 1 | let left = build_left_table(); |
1101 | 1 | let right = build_right_table(); |
1102 | 1 | |
1103 | 1 | let filter = prepare_join_filter(); |
1104 | 1 | let (columns, batches) = multi_partitioned_join_collect( |
1105 | 1 | left, |
1106 | 1 | right, |
1107 | 1 | &JoinType::RightAnti, |
1108 | 1 | Some(filter), |
1109 | 1 | task_ctx, |
1110 | 1 | ) |
1111 | 1 | .await?0 ; |
1112 | 1 | assert_eq!(columns, vec!["a2", "b2", "c2"]); |
1113 | 1 | let expected = [ |
1114 | 1 | "+----+----+-----+", |
1115 | 1 | "| a2 | b2 | c2 |", |
1116 | 1 | "+----+----+-----+", |
1117 | 1 | "| 10 | 10 | 100 |", |
1118 | 1 | "| 12 | 10 | 40 |", |
1119 | 1 | "+----+----+-----+", |
1120 | 1 | ]; |
1121 | 1 | |
1122 | 1 | assert_batches_sorted_eq!(expected, &batches); |
1123 | 1 | |
1124 | 1 | Ok(()) |
1125 | 1 | } |
1126 | | |
1127 | | #[tokio::test] |
1128 | 1 | async fn test_overallocation() -> Result<()> { |
1129 | 1 | let left = build_table( |
1130 | 1 | ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), |
1131 | 1 | ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), |
1132 | 1 | ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), |
1133 | 1 | None, |
1134 | 1 | Vec::new(), |
1135 | 1 | ); |
1136 | 1 | let right = build_table( |
1137 | 1 | ("a2", &vec![10, 11]), |
1138 | 1 | ("b2", &vec![12, 13]), |
1139 | 1 | ("c2", &vec![14, 15]), |
1140 | 1 | None, |
1141 | 1 | Vec::new(), |
1142 | 1 | ); |
1143 | 1 | let filter = prepare_join_filter(); |
1144 | 1 | |
1145 | 1 | let join_types = vec![ |
1146 | 1 | JoinType::Inner, |
1147 | 1 | JoinType::Left, |
1148 | 1 | JoinType::Right, |
1149 | 1 | JoinType::Full, |
1150 | 1 | JoinType::LeftSemi, |
1151 | 1 | JoinType::LeftAnti, |
1152 | 1 | JoinType::RightSemi, |
1153 | 1 | JoinType::RightAnti, |
1154 | 1 | ]; |
1155 | 1 | |
1156 | 9 | for join_type8 in join_types { |
1157 | 8 | let runtime = RuntimeEnvBuilder::new() |
1158 | 8 | .with_memory_limit(100, 1.0) |
1159 | 8 | .build_arc()?0 ; |
1160 | 8 | let task_ctx = TaskContext::default().with_runtime(runtime); |
1161 | 8 | let task_ctx = Arc::new(task_ctx); |
1162 | 1 | |
1163 | 8 | let err = multi_partitioned_join_collect( |
1164 | 8 | Arc::clone(&left), |
1165 | 8 | Arc::clone(&right), |
1166 | 8 | &join_type, |
1167 | 8 | Some(filter.clone()), |
1168 | 8 | task_ctx, |
1169 | 8 | ) |
1170 | 1 | .await0 |
1171 | 8 | .unwrap_err(); |
1172 | 8 | |
1173 | 8 | assert_contains!( |
1174 | 8 | err.to_string(), |
1175 | 8 | "External error: Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: NestedLoopJoinLoad[0]" |
1176 | 8 | ); |
1177 | 1 | } |
1178 | 1 | |
1179 | 1 | Ok(()) |
1180 | 1 | } |
1181 | | |
1182 | 36 | fn prepare_mod_join_filter() -> JoinFilter { |
1183 | 36 | let column_indices = vec![ |
1184 | 36 | ColumnIndex { |
1185 | 36 | index: 1, |
1186 | 36 | side: JoinSide::Left, |
1187 | 36 | }, |
1188 | 36 | ColumnIndex { |
1189 | 36 | index: 1, |
1190 | 36 | side: JoinSide::Right, |
1191 | 36 | }, |
1192 | 36 | ]; |
1193 | 36 | let intermediate_schema = Schema::new(vec![ |
1194 | 36 | Field::new("x", DataType::Int32, true), |
1195 | 36 | Field::new("x", DataType::Int32, true), |
1196 | 36 | ]); |
1197 | 36 | |
1198 | 36 | // left.b1 % 3 |
1199 | 36 | let left_mod = Arc::new(BinaryExpr::new( |
1200 | 36 | Arc::new(Column::new("x", 0)), |
1201 | 36 | Operator::Modulo, |
1202 | 36 | Arc::new(Literal::new(ScalarValue::Int32(Some(3)))), |
1203 | 36 | )) as Arc<dyn PhysicalExpr>; |
1204 | 36 | // left.b1 % 3 != 0 |
1205 | 36 | let left_filter = Arc::new(BinaryExpr::new( |
1206 | 36 | left_mod, |
1207 | 36 | Operator::NotEq, |
1208 | 36 | Arc::new(Literal::new(ScalarValue::Int32(Some(0)))), |
1209 | 36 | )) as Arc<dyn PhysicalExpr>; |
1210 | 36 | |
1211 | 36 | // right.b2 % 5 |
1212 | 36 | let right_mod = Arc::new(BinaryExpr::new( |
1213 | 36 | Arc::new(Column::new("x", 1)), |
1214 | 36 | Operator::Modulo, |
1215 | 36 | Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), |
1216 | 36 | )) as Arc<dyn PhysicalExpr>; |
1217 | 36 | // right.b2 % 5 != 0 |
1218 | 36 | let right_filter = Arc::new(BinaryExpr::new( |
1219 | 36 | right_mod, |
1220 | 36 | Operator::NotEq, |
1221 | 36 | Arc::new(Literal::new(ScalarValue::Int32(Some(0)))), |
1222 | 36 | )) as Arc<dyn PhysicalExpr>; |
1223 | 36 | // filter = left.b1 % 3 != 0 and right.b2 % 5 != 0 |
1224 | 36 | let filter_expression = |
1225 | 36 | Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter)) |
1226 | 36 | as Arc<dyn PhysicalExpr>; |
1227 | 36 | |
1228 | 36 | JoinFilter::new(filter_expression, column_indices, intermediate_schema) |
1229 | 36 | } |
1230 | | |
1231 | 72 | fn generate_columns(num_columns: usize, num_rows: usize) -> Vec<Vec<i32>> { |
1232 | 72.0k | let column = (1..=num_rows).map(|x| x as i32).collect(); |
1233 | 72 | vec![column; num_columns] |
1234 | 72 | } |
1235 | | |
1236 | 36 | #[rstest] |
1237 | | #[tokio::test] |
1238 | | async fn join_maintains_right_order( |
1239 | | #[values( |
1240 | | JoinType::Inner, |
1241 | | JoinType::Right, |
1242 | | JoinType::RightAnti, |
1243 | | JoinType::RightSemi |
1244 | | )] |
1245 | | join_type: JoinType, |
1246 | | #[values(1, 100, 1000)] left_batch_size: usize, |
1247 | | #[values(1, 100, 1000)] right_batch_size: usize, |
1248 | | ) -> Result<()> { |
1249 | | let left_columns = generate_columns(3, 1000); |
1250 | | let left = build_table( |
1251 | | ("a1", &left_columns[0]), |
1252 | | ("b1", &left_columns[1]), |
1253 | | ("c1", &left_columns[2]), |
1254 | | Some(left_batch_size), |
1255 | | Vec::new(), |
1256 | | ); |
1257 | | |
1258 | | let right_columns = generate_columns(3, 1000); |
1259 | | let right = build_table( |
1260 | | ("a2", &right_columns[0]), |
1261 | | ("b2", &right_columns[1]), |
1262 | | ("c2", &right_columns[2]), |
1263 | | Some(right_batch_size), |
1264 | | vec!["a2", "b2", "c2"], |
1265 | | ); |
1266 | | |
1267 | | let filter = prepare_mod_join_filter(); |
1268 | | |
1269 | | let nested_loop_join = Arc::new(NestedLoopJoinExec::try_new( |
1270 | | left, |
1271 | | Arc::clone(&right), |
1272 | | Some(filter), |
1273 | | &join_type, |
1274 | | )?) as Arc<dyn ExecutionPlan>; |
1275 | | assert_eq!(nested_loop_join.maintains_input_order(), vec![false, true]); |
1276 | | |
1277 | | let right_column_indices = match join_type { |
1278 | | JoinType::Inner | JoinType::Right => vec![3, 4, 5], |
1279 | | JoinType::RightAnti | JoinType::RightSemi => vec![0, 1, 2], |
1280 | | _ => unreachable!(), |
1281 | | }; |
1282 | | |
1283 | | let right_ordering = right.output_ordering().unwrap(); |
1284 | | let join_ordering = nested_loop_join.output_ordering().unwrap(); |
1285 | | for (right, join) in right_ordering.iter().zip(join_ordering.iter()) { |
1286 | | let right_column = right.expr.as_any().downcast_ref::<Column>().unwrap(); |
1287 | | let join_column = join.expr.as_any().downcast_ref::<Column>().unwrap(); |
1288 | | assert_eq!(right_column.name(), join_column.name()); |
1289 | | assert_eq!( |
1290 | | right_column_indices[right_column.index()], |
1291 | | join_column.index() |
1292 | | ); |
1293 | | assert_eq!(right.options, join.options); |
1294 | | } |
1295 | | |
1296 | | let batches = nested_loop_join |
1297 | | .execute(0, Arc::new(TaskContext::default()))? |
1298 | | .try_collect::<Vec<_>>() |
1299 | | .await?; |
1300 | | |
1301 | | // Make sure that the order of the right side is maintained |
1302 | | let mut prev_values = [i32::MIN, i32::MIN, i32::MIN]; |
1303 | | |
1304 | | for (batch_index, batch) in batches.iter().enumerate() { |
1305 | | let columns: Vec<_> = right_column_indices |
1306 | | .iter() |
1307 | 36.3k | .map(|&i| { |
1308 | 36.3k | batch |
1309 | 36.3k | .column(i) |
1310 | 36.3k | .as_any() |
1311 | 36.3k | .downcast_ref::<Int32Array>() |
1312 | 36.3k | .unwrap() |
1313 | 36.3k | }) |
1314 | | .collect(); |
1315 | | |
1316 | | for row in 0..batch.num_rows() { |
1317 | | let current_values = [ |
1318 | | columns[0].value(row), |
1319 | | columns[1].value(row), |
1320 | | columns[2].value(row), |
1321 | | ]; |
1322 | | assert!( |
1323 | | current_values |
1324 | | .into_iter() |
1325 | | .zip(prev_values) |
1326 | 28.8M | .all(|(current, prev)| current >= prev), |
1327 | | "batch_index: {} row: {} current: {:?}, prev: {:?}", |
1328 | | batch_index, |
1329 | | row, |
1330 | | current_values, |
1331 | | prev_values |
1332 | | ); |
1333 | | prev_values = current_values; |
1334 | | } |
1335 | | } |
1336 | | |
1337 | | Ok(()) |
1338 | | } |
1339 | | |
1340 | | /// Returns the column names on the schema |
1341 | 16 | fn columns(schema: &Schema) -> Vec<String> { |
1342 | 72 | schema.fields().iter().map(|f| f.name().clone()).collect() |
1343 | 16 | } |
1344 | | } |