/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/joins/sort_merge_join.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Defines the Sort-Merge join execution plan. |
19 | | //! A Sort-Merge join plan consumes two sorted child plans and produces
20 | | //! joined output according to the given join type and other options.
21 | | //! The Sort-Merge join feature is currently experimental.
22 | | |
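A minimal construction sketch, assuming `left` and `right` are already-sorted `ExecutionPlan`s joined on their first columns ("a" and "b" are hypothetical names) and that `Column` is `datafusion_physical_expr::expressions::Column`; it simply mirrors the `try_new` signature defined below, not canonical usage:

let on: JoinOn = vec![(
    Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>, // assumed left key column
    Arc::new(Column::new("b", 0)) as Arc<dyn PhysicalExpr>, // assumed right key column
)];
let join = SortMergeJoinExec::try_new(
    left,                         // assumed: sorted left child plan
    right,                        // assumed: sorted right child plan
    on,
    None,                         // no extra join filter
    JoinType::Inner,
    vec![SortOptions::default()], // one SortOptions per `on` pair
    false,                        // null_equals_null
)?;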
23 | | use std::any::Any; |
24 | | use std::cmp::Ordering; |
25 | | use std::collections::{HashMap, VecDeque}; |
26 | | use std::fmt::Formatter; |
27 | | use std::fs::File; |
28 | | use std::io::BufReader; |
29 | | use std::mem; |
30 | | use std::ops::Range; |
31 | | use std::pin::Pin; |
32 | | use std::sync::Arc; |
33 | | use std::task::{Context, Poll}; |
34 | | |
35 | | use arrow::array::*; |
36 | | use arrow::compute::{self, concat_batches, take, SortOptions}; |
37 | | use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; |
38 | | use arrow::error::ArrowError; |
39 | | use arrow::ipc::reader::FileReader; |
40 | | use arrow_array::types::UInt64Type; |
41 | | use futures::{Stream, StreamExt}; |
42 | | use hashbrown::HashSet; |
43 | | |
44 | | use datafusion_common::{ |
45 | | exec_err, internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, |
46 | | Result, |
47 | | }; |
48 | | use datafusion_execution::disk_manager::RefCountedTempFile; |
49 | | use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; |
50 | | use datafusion_execution::runtime_env::RuntimeEnv; |
51 | | use datafusion_execution::TaskContext; |
52 | | use datafusion_physical_expr::equivalence::join_equivalence_properties; |
53 | | use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; |
54 | | use datafusion_physical_expr_common::sort_expr::LexRequirement; |
55 | | |
56 | | use crate::expressions::PhysicalSortExpr; |
57 | | use crate::joins::utils::{ |
58 | | build_join_schema, check_join_is_valid, estimate_join_statistics, |
59 | | symmetric_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, |
60 | | }; |
61 | | use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; |
62 | | use crate::spill::spill_record_batches; |
63 | | use crate::{ |
64 | | execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, |
65 | | ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, |
66 | | RecordBatchStream, SendableRecordBatchStream, Statistics, |
67 | | }; |
68 | | |
69 | | /// Join execution plan that executes partitions in parallel and combines them into a set
70 | | /// of output partitions.
71 | | #[derive(Debug)] |
72 | | pub struct SortMergeJoinExec { |
73 | | /// Left sorted joining execution plan |
74 | | pub left: Arc<dyn ExecutionPlan>, |
75 | | /// Right sorted joining execution plan
76 | | pub right: Arc<dyn ExecutionPlan>, |
77 | | /// Set of common columns used to join on |
78 | | pub on: JoinOn, |
79 | | /// Filters which are applied while finding matching rows |
80 | | pub filter: Option<JoinFilter>, |
81 | | /// How the join is performed |
82 | | pub join_type: JoinType, |
83 | | /// The schema once the join is applied |
84 | | schema: SchemaRef, |
85 | | /// Execution metrics |
86 | | metrics: ExecutionPlanMetricsSet, |
87 | | /// The left SortExpr |
88 | | left_sort_exprs: Vec<PhysicalSortExpr>, |
89 | | /// The right SortExpr |
90 | | right_sort_exprs: Vec<PhysicalSortExpr>, |
91 | | /// Sort options of join columns used in sorting left and right execution plans |
92 | | pub sort_options: Vec<SortOptions>, |
93 | | /// If null_equals_null is true, null == null else null != null |
94 | | pub null_equals_null: bool, |
95 | | /// Cache holding plan properties like equivalences, output partitioning etc. |
96 | | cache: PlanProperties, |
97 | | } |
98 | | |
99 | | impl SortMergeJoinExec { |
100 | | /// Tries to create a new [SortMergeJoinExec]. |
101 | | /// The inputs are sorted using `sort_options`, which are applied to the columns in `on`.
102 | | /// # Error |
103 | | /// This function errors when it is not possible to join the left and right sides on keys `on`. |
104 | 79 | pub fn try_new( |
105 | 79 | left: Arc<dyn ExecutionPlan>, |
106 | 79 | right: Arc<dyn ExecutionPlan>, |
107 | 79 | on: JoinOn, |
108 | 79 | filter: Option<JoinFilter>, |
109 | 79 | join_type: JoinType, |
110 | 79 | sort_options: Vec<SortOptions>, |
111 | 79 | null_equals_null: bool, |
112 | 79 | ) -> Result<Self> { |
113 | 79 | let left_schema = left.schema(); |
114 | 79 | let right_schema = right.schema(); |
115 | 79 | |
116 | 79 | if join_type == JoinType::RightSemi { |
117 | 0 | return not_impl_err!( |
118 | 0 | "SortMergeJoinExec does not support JoinType::RightSemi" |
119 | 0 | ); |
120 | 79 | } |
121 | 79 | |
122 | 79 | check_join_is_valid(&left_schema, &right_schema, &on)?;
123 | 79 | if sort_options.len() != on.len() { |
124 | 0 | return plan_err!( |
125 | 0 | "Expected number of sort options: {}, actual: {}", |
126 | 0 | on.len(), |
127 | 0 | sort_options.len() |
128 | 0 | ); |
129 | 79 | } |
130 | 79 | |
131 | 79 | let (left_sort_exprs, right_sort_exprs): (Vec<_>, Vec<_>) = on |
132 | 79 | .iter() |
133 | 79 | .zip(sort_options.iter()) |
134 | 84 | .map(|((l, r), sort_op)| { |
135 | 84 | let left = PhysicalSortExpr { |
136 | 84 | expr: Arc::clone(l), |
137 | 84 | options: *sort_op, |
138 | 84 | }; |
139 | 84 | let right = PhysicalSortExpr { |
140 | 84 | expr: Arc::clone(r), |
141 | 84 | options: *sort_op, |
142 | 84 | }; |
143 | 84 | (left, right) |
144 | 84 | }) |
145 | 79 | .unzip(); |
146 | 79 | |
147 | 79 | let schema = |
148 | 79 | Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0); |
149 | 79 | let cache = |
150 | 79 | Self::compute_properties(&left, &right, Arc::clone(&schema), join_type, &on); |
151 | 79 | Ok(Self { |
152 | 79 | left, |
153 | 79 | right, |
154 | 79 | on, |
155 | 79 | filter, |
156 | 79 | join_type, |
157 | 79 | schema, |
158 | 79 | metrics: ExecutionPlanMetricsSet::new(), |
159 | 79 | left_sort_exprs, |
160 | 79 | right_sort_exprs, |
161 | 79 | sort_options, |
162 | 79 | null_equals_null, |
163 | 79 | cache, |
164 | 79 | }) |
165 | 79 | } |
166 | | |
167 | | /// Get the probe side (i.e. the streaming side) for this sort merge join.
168 | | /// In the current implementation, the probe side is determined by the join type.
169 | 158 | pub fn probe_side(join_type: &JoinType) -> JoinSide { |
170 | 158 | // When output schema contains only the right side, probe side is right. |
171 | 158 | // Otherwise probe side is the left side. |
172 | 158 | match join_type { |
173 | | JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { |
174 | 26 | JoinSide::Right |
175 | | } |
176 | | JoinType::Inner |
177 | | | JoinType::Left |
178 | | | JoinType::Full |
179 | | | JoinType::LeftAnti |
180 | 132 | | JoinType::LeftSemi => JoinSide::Left, |
181 | | } |
182 | 158 | } |
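A hedged illustration of the mapping above (`JoinSide` is already compared with `==` elsewhere in this file, so `assert_eq!` is assumed to work):

assert_eq!(SortMergeJoinExec::probe_side(&JoinType::Right), JoinSide::Right);
assert_eq!(SortMergeJoinExec::probe_side(&JoinType::RightAnti), JoinSide::Right);
assert_eq!(SortMergeJoinExec::probe_side(&JoinType::Inner), JoinSide::Left);
assert_eq!(SortMergeJoinExec::probe_side(&JoinType::Full), JoinSide::Left);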
183 | | |
184 | | /// Calculate order preservation flags for this sort merge join. |
185 | 79 | fn maintains_input_order(join_type: JoinType) -> Vec<bool> { |
186 | 79 | match join_type { |
187 | 19 | JoinType::Inner => vec![true, false], |
188 | 35 | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => vec![true, false], |
189 | | JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { |
190 | 13 | vec![false, true] |
191 | | } |
192 | 12 | _ => vec![false, false], |
193 | | } |
194 | 79 | } |
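Concretely, the flags read as `[left_order_maintained, right_order_maintained]`; since the helper is private, a check like this would have to live in a test inside this module (a sketch):

assert_eq!(SortMergeJoinExec::maintains_input_order(JoinType::Inner), vec![true, false]);
assert_eq!(SortMergeJoinExec::maintains_input_order(JoinType::Right), vec![false, true]);
assert_eq!(SortMergeJoinExec::maintains_input_order(JoinType::Full), vec![false, false]);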
195 | | |
196 | | /// Set of common columns used to join on |
197 | 0 | pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] { |
198 | 0 | &self.on |
199 | 0 | } |
200 | | |
201 | 0 | pub fn right(&self) -> &Arc<dyn ExecutionPlan> { |
202 | 0 | &self.right |
203 | 0 | } |
204 | | |
205 | 0 | pub fn join_type(&self) -> JoinType { |
206 | 0 | self.join_type |
207 | 0 | } |
208 | | |
209 | 0 | pub fn left(&self) -> &Arc<dyn ExecutionPlan> { |
210 | 0 | &self.left |
211 | 0 | } |
212 | | |
213 | | /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. |
214 | 79 | fn compute_properties( |
215 | 79 | left: &Arc<dyn ExecutionPlan>, |
216 | 79 | right: &Arc<dyn ExecutionPlan>, |
217 | 79 | schema: SchemaRef, |
218 | 79 | join_type: JoinType, |
219 | 79 | join_on: JoinOnRef, |
220 | 79 | ) -> PlanProperties { |
221 | 79 | // Calculate equivalence properties: |
222 | 79 | let eq_properties = join_equivalence_properties( |
223 | 79 | left.equivalence_properties().clone(), |
224 | 79 | right.equivalence_properties().clone(), |
225 | 79 | &join_type, |
226 | 79 | schema, |
227 | 79 | &Self::maintains_input_order(join_type), |
228 | 79 | Some(Self::probe_side(&join_type)), |
229 | 79 | join_on, |
230 | 79 | ); |
231 | 79 | |
232 | 79 | let output_partitioning = |
233 | 79 | symmetric_join_output_partitioning(left, right, &join_type); |
234 | 79 | |
235 | 79 | // Determine execution mode: |
236 | 79 | let mode = execution_mode_from_children([left, right]); |
237 | 79 | |
238 | 79 | PlanProperties::new(eq_properties, output_partitioning, mode) |
239 | 79 | } |
240 | | } |
241 | | |
242 | | impl DisplayAs for SortMergeJoinExec { |
243 | 0 | fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { |
244 | 0 | match t { |
245 | | DisplayFormatType::Default | DisplayFormatType::Verbose => { |
246 | 0 | let on = self |
247 | 0 | .on |
248 | 0 | .iter() |
249 | 0 | .map(|(c1, c2)| format!("({}, {})", c1, c2)) |
250 | 0 | .collect::<Vec<String>>() |
251 | 0 | .join(", "); |
252 | 0 | write!( |
253 | 0 | f, |
254 | 0 | "SortMergeJoin: join_type={:?}, on=[{}]{}", |
255 | 0 | self.join_type, |
256 | 0 | on, |
257 | 0 | self.filter.as_ref().map_or("".to_string(), |f| format!( |
258 | 0 | ", filter={}", |
259 | 0 | f.expression() |
260 | 0 | )) |
261 | 0 | ) |
262 | 0 | } |
263 | 0 | } |
264 | 0 | } |
265 | | } |
266 | | |
267 | | impl ExecutionPlan for SortMergeJoinExec { |
268 | 0 | fn name(&self) -> &'static str { |
269 | 0 | "SortMergeJoinExec" |
270 | 0 | } |
271 | | |
272 | 0 | fn as_any(&self) -> &dyn Any { |
273 | 0 | self |
274 | 0 | } |
275 | | |
276 | 19 | fn properties(&self) -> &PlanProperties { |
277 | 19 | &self.cache |
278 | 19 | } |
279 | | |
280 | 0 | fn required_input_distribution(&self) -> Vec<Distribution> { |
281 | 0 | let (left_expr, right_expr) = self |
282 | 0 | .on |
283 | 0 | .iter() |
284 | 0 | .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) |
285 | 0 | .unzip(); |
286 | 0 | vec![ |
287 | 0 | Distribution::HashPartitioned(left_expr), |
288 | 0 | Distribution::HashPartitioned(right_expr), |
289 | 0 | ] |
290 | 0 | } |
291 | | |
292 | 0 | fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> { |
293 | 0 | vec![ |
294 | 0 | Some(PhysicalSortRequirement::from_sort_exprs( |
295 | 0 | &self.left_sort_exprs, |
296 | 0 | )), |
297 | 0 | Some(PhysicalSortRequirement::from_sort_exprs( |
298 | 0 | &self.right_sort_exprs, |
299 | 0 | )), |
300 | 0 | ] |
301 | 0 | } |
302 | | |
303 | 0 | fn maintains_input_order(&self) -> Vec<bool> { |
304 | 0 | Self::maintains_input_order(self.join_type) |
305 | 0 | } |
306 | | |
307 | 0 | fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { |
308 | 0 | vec![&self.left, &self.right] |
309 | 0 | } |
310 | | |
311 | 0 | fn with_new_children( |
312 | 0 | self: Arc<Self>, |
313 | 0 | children: Vec<Arc<dyn ExecutionPlan>>, |
314 | 0 | ) -> Result<Arc<dyn ExecutionPlan>> { |
315 | 0 | match &children[..] { |
316 | 0 | [left, right] => Ok(Arc::new(SortMergeJoinExec::try_new( |
317 | 0 | Arc::clone(left), |
318 | 0 | Arc::clone(right), |
319 | 0 | self.on.clone(), |
320 | 0 | self.filter.clone(), |
321 | 0 | self.join_type, |
322 | 0 | self.sort_options.clone(), |
323 | 0 | self.null_equals_null, |
324 | 0 | )?)), |
325 | 0 | _ => internal_err!("SortMergeJoin wrong number of children"), |
326 | | } |
327 | 0 | } |
328 | | |
329 | 79 | fn execute( |
330 | 79 | &self, |
331 | 79 | partition: usize, |
332 | 79 | context: Arc<TaskContext>, |
333 | 79 | ) -> Result<SendableRecordBatchStream> { |
334 | 79 | let left_partitions = self.left.output_partitioning().partition_count(); |
335 | 79 | let right_partitions = self.right.output_partitioning().partition_count(); |
336 | 79 | if left_partitions != right_partitions { |
337 | 0 | return internal_err!( |
338 | 0 | "Invalid SortMergeJoinExec, partition count mismatch {left_partitions}!={right_partitions},\ |
339 | 0 | consider using RepartitionExec" |
340 | 0 | ); |
341 | 79 | } |
342 | 79 | let (on_left, on_right) = self.on.iter().cloned().unzip(); |
343 | 79 | let (streamed, buffered, on_streamed, on_buffered) = |
344 | 79 | if SortMergeJoinExec::probe_side(&self.join_type) == JoinSide::Left { |
345 | 66 | ( |
346 | 66 | Arc::clone(&self.left), |
347 | 66 | Arc::clone(&self.right), |
348 | 66 | on_left, |
349 | 66 | on_right, |
350 | 66 | ) |
351 | | } else { |
352 | 13 | ( |
353 | 13 | Arc::clone(&self.right), |
354 | 13 | Arc::clone(&self.left), |
355 | 13 | on_right, |
356 | 13 | on_left, |
357 | 13 | ) |
358 | | }; |
359 | | |
360 | | // execute children plans |
361 | 79 | let streamed = streamed.execute(partition, Arc::clone(&context))?;
362 | 79 | let buffered = buffered.execute(partition, Arc::clone(&context))?;
363 | | |
364 | | // create output buffer |
365 | 79 | let batch_size = context.session_config().batch_size(); |
366 | 79 | |
367 | 79 | // create memory reservation |
368 | 79 | let reservation = MemoryConsumer::new(format!("SMJStream[{partition}]")) |
369 | 79 | .register(context.memory_pool()); |
370 | 79 | |
371 | 79 | // create join stream |
372 | 79 | Ok(Box::pin(SMJStream::try_new( |
373 | 79 | Arc::clone(&self.schema), |
374 | 79 | self.sort_options.clone(), |
375 | 79 | self.null_equals_null, |
376 | 79 | streamed, |
377 | 79 | buffered, |
378 | 79 | on_streamed, |
379 | 79 | on_buffered, |
380 | 79 | self.filter.clone(), |
381 | 79 | self.join_type, |
382 | 79 | batch_size, |
383 | 79 | SortMergeJoinMetrics::new(partition, &self.metrics), |
384 | 79 | reservation, |
385 | 79 | context.runtime_env(), |
386 | 79 | )?))
387 | 79 | } |
388 | | |
389 | 240 | fn metrics(&self) -> Option<MetricsSet> { |
390 | 240 | Some(self.metrics.clone_inner()) |
391 | 240 | } |
392 | | |
393 | 0 | fn statistics(&self) -> Result<Statistics> { |
394 | 0 | // TODO stats: it is not possible in general to know the output size of joins |
395 | 0 | // There are some special cases though, for example: |
396 | 0 | // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` |
397 | 0 | estimate_join_statistics( |
398 | 0 | Arc::clone(&self.left), |
399 | 0 | Arc::clone(&self.right), |
400 | 0 | self.on.clone(), |
401 | 0 | &self.join_type, |
402 | 0 | &self.schema, |
403 | 0 | ) |
404 | 0 | } |
405 | | } |
406 | | |
407 | | /// Metrics for SortMergeJoinExec |
408 | | #[allow(dead_code)] |
409 | | struct SortMergeJoinMetrics { |
410 | | /// Total time for joining probe-side batches to the build-side batches |
411 | | join_time: metrics::Time, |
412 | | /// Number of batches consumed by this operator |
413 | | input_batches: metrics::Count, |
414 | | /// Number of rows consumed by this operator |
415 | | input_rows: metrics::Count, |
416 | | /// Number of batches produced by this operator |
417 | | output_batches: metrics::Count, |
418 | | /// Number of rows produced by this operator |
419 | | output_rows: metrics::Count, |
420 | | /// Peak memory used for buffered data. |
421 | | /// Calculated as sum of peak memory values across partitions |
422 | | peak_mem_used: metrics::Gauge, |
423 | | /// count of spills during the execution of the operator |
424 | | spill_count: Count, |
425 | | /// total spilled bytes during the execution of the operator |
426 | | spilled_bytes: Count, |
427 | | /// total spilled rows during the execution of the operator |
428 | | spilled_rows: Count, |
429 | | } |
430 | | |
431 | | impl SortMergeJoinMetrics { |
432 | | #[allow(dead_code)] |
433 | 79 | pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { |
434 | 79 | let join_time = MetricBuilder::new(metrics).subset_time("join_time", partition); |
435 | 79 | let input_batches = |
436 | 79 | MetricBuilder::new(metrics).counter("input_batches", partition); |
437 | 79 | let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); |
438 | 79 | let output_batches = |
439 | 79 | MetricBuilder::new(metrics).counter("output_batches", partition); |
440 | 79 | let output_rows = MetricBuilder::new(metrics).output_rows(partition); |
441 | 79 | let peak_mem_used = MetricBuilder::new(metrics).gauge("peak_mem_used", partition); |
442 | 79 | let spill_count = MetricBuilder::new(metrics).spill_count(partition); |
443 | 79 | let spilled_bytes = MetricBuilder::new(metrics).spilled_bytes(partition); |
444 | 79 | let spilled_rows = MetricBuilder::new(metrics).spilled_rows(partition); |
445 | 79 | |
446 | 79 | Self { |
447 | 79 | join_time, |
448 | 79 | input_batches, |
449 | 79 | input_rows, |
450 | 79 | output_batches, |
451 | 79 | output_rows, |
452 | 79 | peak_mem_used, |
453 | 79 | spill_count, |
454 | 79 | spilled_bytes, |
455 | 79 | spilled_rows, |
456 | 79 | } |
457 | 79 | } |
458 | | } |
459 | | |
460 | | /// State of SMJ stream |
461 | | #[derive(Debug, PartialEq, Eq)] |
462 | | enum SMJState { |
463 | | /// Init joining with a new streamed row or a new buffered batches |
464 | | Init, |
465 | | /// Polling one streamed row or one buffered batch, or both |
466 | | Polling, |
467 | | /// Joining polled data and making output |
468 | | JoinOutput, |
469 | | /// No more output |
470 | | Exhausted, |
471 | | } |
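A hypothetical sketch of the coarse transitions these states imply; the real driver in `SMJStream::poll_next` below additionally emits output batches, propagates `Poll::Pending`, and decides which side to poll:

fn next_state(state: SMJState, streamed_done: bool, buffered_done: bool) -> SMJState {
    match state {
        // both inputs drained: nothing left to join
        SMJState::Init if streamed_done && buffered_done => SMJState::Exhausted,
        // otherwise fetch the next streamed row and/or buffered batch
        SMJState::Init => SMJState::Polling,
        // one streamed row and one buffered batch are ready: compare join keys
        SMJState::Polling => SMJState::JoinOutput,
        // joined rows were staged (and possibly emitted): start over
        SMJState::JoinOutput => SMJState::Init,
        SMJState::Exhausted => SMJState::Exhausted,
    }
}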
472 | | |
473 | | /// State of streamed data stream |
474 | | #[derive(Debug, PartialEq, Eq)] |
475 | | enum StreamedState { |
476 | | /// Init polling |
477 | | Init, |
478 | | /// Polling one streamed row |
479 | | Polling, |
480 | | /// Ready to produce one streamed row |
481 | | Ready, |
482 | | /// No more streamed row |
483 | | Exhausted, |
484 | | } |
485 | | |
486 | | /// State of buffered data stream |
487 | | #[derive(Debug, PartialEq, Eq)] |
488 | | enum BufferedState { |
489 | | /// Init polling |
490 | | Init, |
491 | | /// Polling first row in the next batch |
492 | | PollingFirst, |
493 | | /// Polling rest rows in the next batch |
494 | | PollingRest, |
495 | | /// Ready to produce one batch |
496 | | Ready, |
497 | | /// No more buffered batches |
498 | | Exhausted, |
499 | | } |
500 | | |
501 | | /// Represents a chunk of joined data from streamed and buffered side |
502 | | struct StreamedJoinedChunk { |
503 | | /// Index of batch in buffered_data |
504 | | buffered_batch_idx: Option<usize>, |
505 | | /// Array builder for streamed indices |
506 | | streamed_indices: UInt64Builder, |
507 | | /// Array builder for buffered indices |
508 | | /// This may contain nulls when a streamed row is joined with nulls (outer joins)
509 | | buffered_indices: UInt64Builder, |
510 | | } |
511 | | |
512 | | struct StreamedBatch { |
513 | | /// The streamed record batch |
514 | | pub batch: RecordBatch, |
515 | | /// The index of row in the streamed batch to compare with buffered batches |
516 | | pub idx: usize, |
517 | | /// The join key arrays of streamed batch which are used to compare with buffered batches |
518 | | /// and to produce output. They are produced by evaluating `on` expressions. |
519 | | pub join_arrays: Vec<ArrayRef>, |
520 | | /// Chunks of indices from buffered side (may be nulls) joined to streamed |
521 | | pub output_indices: Vec<StreamedJoinedChunk>, |
522 | | /// Index of currently scanned batch from buffered data |
523 | | pub buffered_batch_idx: Option<usize>, |
524 | | /// Indices that found a match for the given join filter |
525 | | /// Used for semi joins to keep track of the streamed indices that got a join filter match
526 | | /// and were already emitted to the output.
527 | | pub join_filter_matched_idxs: HashSet<u64>, |
528 | | } |
529 | | |
530 | | impl StreamedBatch { |
531 | 130 | fn new(batch: RecordBatch, on_column: &[Arc<dyn PhysicalExpr>]) -> Self { |
532 | 130 | let join_arrays = join_arrays(&batch, on_column); |
533 | 130 | StreamedBatch { |
534 | 130 | batch, |
535 | 130 | idx: 0, |
536 | 130 | join_arrays, |
537 | 130 | output_indices: vec![], |
538 | 130 | buffered_batch_idx: None, |
539 | 130 | join_filter_matched_idxs: HashSet::new(), |
540 | 130 | } |
541 | 130 | } |
542 | | |
543 | 79 | fn new_empty(schema: SchemaRef) -> Self { |
544 | 79 | StreamedBatch { |
545 | 79 | batch: RecordBatch::new_empty(schema), |
546 | 79 | idx: 0, |
547 | 79 | join_arrays: vec![], |
548 | 79 | output_indices: vec![], |
549 | 79 | buffered_batch_idx: None, |
550 | 79 | join_filter_matched_idxs: HashSet::new(), |
551 | 79 | } |
552 | 79 | } |
553 | | |
554 | | /// Appends a new pair consisting of the current streamed index and the row at
555 | | /// `buffered_idx` of the buffered batch at index `buffered_batch_idx`.
556 | 693 | fn append_output_pair( |
557 | 693 | &mut self, |
558 | 693 | buffered_batch_idx: Option<usize>, |
559 | 693 | buffered_idx: Option<usize>, |
560 | 693 | ) { |
561 | 693 | // If no current chunk exists or current chunk is not for current buffered batch, |
562 | 693 | // create a new chunk |
563 | 693 | if self.output_indices.is_empty() || self.buffered_batch_idx != buffered_batch_idx
564 | 504 | { |
565 | 504 | self.output_indices.push(StreamedJoinedChunk { |
566 | 504 | buffered_batch_idx, |
567 | 504 | streamed_indices: UInt64Builder::with_capacity(1), |
568 | 504 | buffered_indices: UInt64Builder::with_capacity(1), |
569 | 504 | }); |
570 | 504 | self.buffered_batch_idx = buffered_batch_idx; |
571 | 504 | };
572 | 693 | let current_chunk = self.output_indices.last_mut().unwrap(); |
573 | 693 | |
574 | 693 | // Append index of streamed batch and index of buffered batch into current chunk |
575 | 693 | current_chunk.streamed_indices.append_value(self.idx as u64); |
576 | 693 | if let Some(idx) = buffered_idx {
577 | 600 | current_chunk.buffered_indices.append_value(idx as u64); |
578 | 600 | } else { |
579 | 93 | current_chunk.buffered_indices.append_null(); |
580 | 93 | } |
581 | 693 | } |
582 | | } |
583 | | |
584 | | /// A buffered batch that contains contiguous rows with same join key |
585 | | #[derive(Debug)] |
586 | | struct BufferedBatch { |
587 | | /// The buffered record batch |
588 | | /// None if the batch spilled to disk
589 | | pub batch: Option<RecordBatch>, |
590 | | /// The range in which the rows share the same join key |
591 | | pub range: Range<usize>, |
592 | | /// Array refs of the join key |
593 | | pub join_arrays: Vec<ArrayRef>, |
594 | | /// Buffered joined index (null joining buffered) |
595 | | pub null_joined: Vec<usize>, |
596 | | /// Size estimation used for reserving / releasing memory |
597 | | pub size_estimation: usize, |
598 | | /// The indices of buffered batch that failed the join filter. |
599 | | /// This is a map between buffered row index and a boolean value indicating whether all joined row |
600 | | /// of the buffered row failed the join filter. |
601 | | /// When dequeuing the buffered batch, we need to produce null joined rows for these indices. |
602 | | pub join_filter_failed_map: HashMap<u64, bool>, |
603 | | /// Current buffered batch number of rows. Equal to batch.num_rows() |
604 | | /// but if the batch is spilled to disk this field is preferable,
605 | | /// as it is less expensive to access
606 | | pub num_rows: usize, |
607 | | /// An optional temp spill file name on the disk if the batch spilled |
608 | | /// None by default |
609 | | /// Some(fileName) if the batch spilled to the disk |
610 | | pub spill_file: Option<RefCountedTempFile>, |
611 | | } |
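To make the `range` field concrete, a hypothetical trace for a single buffered batch whose join keys are [1, 1, 2, 2, 3], as driven by `poll_buffered_batches` below:

// PollingFirst: BufferedBatch::new(batch, 0..1, on)       -> range = 0..1 (first key 1)
// PollingRest:  next row's key equals the head row's key  -> range = 0..2 (both 1s)
// Init:         range.start = range.end; range.end += 1   -> range = 2..3 (key 2)
// PollingRest:  extend while keys compare equal           -> range = 2..4
// Init, then PollingRest again                            -> range = 4..5 (key 3)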
612 | | |
613 | | impl BufferedBatch { |
614 | 130 | fn new( |
615 | 130 | batch: RecordBatch, |
616 | 130 | range: Range<usize>, |
617 | 130 | on_column: &[PhysicalExprRef], |
618 | 130 | ) -> Self { |
619 | 130 | let join_arrays = join_arrays(&batch, on_column); |
620 | 130 | |
621 | 130 | // Estimation is calculated as |
622 | 130 | // inner batch size |
623 | 130 | // + join keys size |
624 | 130 | // + worst case null_joined (as vector capacity * element size) |
625 | 130 | // + Range size |
626 | 130 | // + size of this estimation |
627 | 130 | let size_estimation = batch.get_array_memory_size() |
628 | 130 | + join_arrays |
629 | 130 | .iter() |
630 | 135 | .map(|arr| arr.get_array_memory_size()) |
631 | 130 | .sum::<usize>() |
632 | 130 | + batch.num_rows().next_power_of_two() * mem::size_of::<usize>() |
633 | 130 | + mem::size_of::<Range<usize>>() |
634 | 130 | + mem::size_of::<usize>(); |
635 | 130 | |
636 | 130 | let num_rows = batch.num_rows(); |
637 | 130 | BufferedBatch { |
638 | 130 | batch: Some(batch), |
639 | 130 | range, |
640 | 130 | join_arrays, |
641 | 130 | null_joined: vec![], |
642 | 130 | size_estimation, |
643 | 130 | join_filter_failed_map: HashMap::new(), |
644 | 130 | num_rows, |
645 | 130 | spill_file: None, |
646 | 130 | } |
647 | 130 | } |
648 | | } |
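A rough worked instance of the estimation above, with hypothetical sizes on a 64-bit target (size_of::<usize>() == 8, size_of::<Range<usize>>() == 16):

// batch.get_array_memory_size()          = 65_536  (64 KiB of column data, assumed)
// join key arrays' memory                =  8_192  (one 8 KiB key column, assumed)
// 1000_usize.next_power_of_two() * 8     =  8_192  (worst-case null_joined capacity)
// size_of::<Range<usize>>()              =     16
// size_of::<usize>()                     =      8
// size_estimation                        = 81_944 bytes reserved for this batch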
649 | | |
650 | | /// Sort-merge join stream that consumes streamed and buffered data stream |
651 | | /// and produces joined output |
652 | | struct SMJStream { |
653 | | /// Current state of the stream |
654 | | pub state: SMJState, |
655 | | /// Output schema |
656 | | pub schema: SchemaRef, |
657 | | /// Sort options of join columns used to sort streamed and buffered data stream |
658 | | pub sort_options: Vec<SortOptions>, |
659 | | /// null == null? |
660 | | pub null_equals_null: bool, |
661 | | /// Input schema of streamed |
662 | | pub streamed_schema: SchemaRef, |
663 | | /// Input schema of buffered |
664 | | pub buffered_schema: SchemaRef, |
665 | | /// Streamed data stream |
666 | | pub streamed: SendableRecordBatchStream, |
667 | | /// Buffered data stream |
668 | | pub buffered: SendableRecordBatchStream, |
669 | | /// Current processing record batch of streamed |
670 | | pub streamed_batch: StreamedBatch, |
671 | | /// Current buffered data |
672 | | pub buffered_data: BufferedData, |
673 | | /// (used in outer join) Is current streamed row joined at least once? |
674 | | pub streamed_joined: bool, |
675 | | /// (used in outer join) Is current buffered batches joined at least once? |
676 | | pub buffered_joined: bool, |
677 | | /// State of streamed |
678 | | pub streamed_state: StreamedState, |
679 | | /// State of buffered |
680 | | pub buffered_state: BufferedState, |
681 | | /// The comparison result of current streamed row and buffered batches |
682 | | pub current_ordering: Ordering, |
683 | | /// Join key columns of streamed |
684 | | pub on_streamed: Vec<PhysicalExprRef>, |
685 | | /// Join key columns of buffered |
686 | | pub on_buffered: Vec<PhysicalExprRef>, |
687 | | /// optional join filter |
688 | | pub filter: Option<JoinFilter>, |
689 | | /// Staging output array builders |
690 | | pub output_record_batches: Vec<RecordBatch>, |
691 | | /// Staging output size, including output batches and staging joined results. |
692 | | /// Increased when we put rows into buffer and decreased after we actually output batches. |
693 | | /// Used to trigger output when sufficient rows are ready |
694 | | pub output_size: usize, |
695 | | /// Target output batch size |
696 | | pub batch_size: usize, |
697 | | /// How the join is performed |
698 | | pub join_type: JoinType, |
699 | | /// Metrics |
700 | | pub join_metrics: SortMergeJoinMetrics, |
701 | | /// Memory reservation |
702 | | pub reservation: MemoryReservation, |
703 | | /// Runtime env |
704 | | pub runtime_env: Arc<RuntimeEnv>, |
705 | | } |
706 | | |
707 | | impl RecordBatchStream for SMJStream { |
708 | 0 | fn schema(&self) -> SchemaRef { |
709 | 0 | Arc::clone(&self.schema) |
710 | 0 | } |
711 | | } |
712 | | |
713 | | impl Stream for SMJStream { |
714 | | type Item = Result<RecordBatch>; |
715 | | |
716 | 429 | fn poll_next( |
717 | 429 | mut self: Pin<&mut Self>, |
718 | 429 | cx: &mut Context<'_>, |
719 | 429 | ) -> Poll<Option<Self::Item>> { |
720 | 429 | let join_time = self.join_metrics.join_time.clone(); |
721 | 429 | let _timer = join_time.timer(); |
722 | | |
723 | | loop { |
724 | 2.28k | match &self.state { |
725 | | SMJState::Init => { |
726 | 648 | let streamed_exhausted = |
727 | 648 | self.streamed_state == StreamedState::Exhausted; |
728 | 648 | let buffered_exhausted = |
729 | 648 | self.buffered_state == BufferedState::Exhausted; |
730 | 648 | self.state = if streamed_exhausted && buffered_exhausted {
731 | 0 | SMJState::Exhausted |
732 | | } else { |
733 | 648 | match self.current_ordering { |
734 | | Ordering::Less | Ordering::Equal => { |
735 | 437 | if !streamed_exhausted { |
736 | 437 | self.streamed_joined = false; |
737 | 437 | self.streamed_state = StreamedState::Init; |
738 | 437 | }
739 | | } |
740 | | Ordering::Greater => { |
741 | 211 | if !buffered_exhausted { |
742 | 211 | self.buffered_joined = false; |
743 | 211 | self.buffered_state = BufferedState::Init; |
744 | 211 | }
745 | | } |
746 | | } |
747 | 648 | SMJState::Polling |
748 | | }; |
749 | | } |
750 | | SMJState::Polling => { |
751 | 648 | if ![StreamedState::Exhausted, StreamedState::Ready] |
752 | 648 | .contains(&self.streamed_state) |
753 | | { |
754 | 437 | match self.poll_streamed_row(cx)? {
755 | 437 | Poll::Ready(_) => {} |
756 | 0 | Poll::Pending => return Poll::Pending, |
757 | | } |
758 | 211 | } |
759 | | |
760 | 648 | if ![BufferedState::Exhausted, BufferedState::Ready] |
761 | 648 | .contains(&self.buffered_state) |
762 | | { |
763 | 290 | match self.poll_buffered_batches(cx)? {
764 | 278 | Poll::Ready(_) => {} |
765 | 0 | Poll::Pending => return Poll::Pending, |
766 | | } |
767 | 358 | } |
768 | 636 | let streamed_exhausted = |
769 | 636 | self.streamed_state == StreamedState::Exhausted; |
770 | 636 | let buffered_exhausted = |
771 | 636 | self.buffered_state == BufferedState::Exhausted; |
772 | 636 | if streamed_exhausted && buffered_exhausted {
773 | 67 | self.state = SMJState::Exhausted; |
774 | 67 | continue; |
775 | 569 | } |
776 | 569 | self.current_ordering = self.compare_streamed_buffered()?;
777 | 569 | self.state = SMJState::JoinOutput; |
778 | | } |
779 | | SMJState::JoinOutput => { |
780 | 878 | self.join_partial()?;
781 | | |
782 | 878 | if self.output_size < self.batch_size { |
783 | 569 | if self.buffered_data.scanning_finished() { |
784 | 569 | self.buffered_data.scanning_reset(); |
785 | 569 | self.state = SMJState::Init; |
786 | 569 | }
787 | | } else { |
788 | 309 | self.freeze_all()?;
789 | 309 | if !self.output_record_batches.is_empty() { |
790 | 309 | let record_batch = self.output_record_batch_and_reset()?;
791 | 309 | return Poll::Ready(Some(Ok(record_batch))); |
792 | 0 | } |
793 | 0 | return Poll::Pending; |
794 | | } |
795 | | } |
796 | | SMJState::Exhausted => { |
797 | 108 | self.freeze_all()?;
798 | 108 | if !self.output_record_batches.is_empty() { |
799 | 41 | let record_batch = self.output_record_batch_and_reset()?;
800 | 41 | return Poll::Ready(Some(Ok(record_batch))); |
801 | 67 | } |
802 | 67 | return Poll::Ready(None); |
803 | | } |
804 | | } |
805 | | } |
806 | 429 | } |
807 | | } |
808 | | |
809 | | impl SMJStream { |
810 | | #[allow(clippy::too_many_arguments)] |
811 | 79 | pub fn try_new( |
812 | 79 | schema: SchemaRef, |
813 | 79 | sort_options: Vec<SortOptions>, |
814 | 79 | null_equals_null: bool, |
815 | 79 | streamed: SendableRecordBatchStream, |
816 | 79 | buffered: SendableRecordBatchStream, |
817 | 79 | on_streamed: Vec<Arc<dyn PhysicalExpr>>, |
818 | 79 | on_buffered: Vec<Arc<dyn PhysicalExpr>>, |
819 | 79 | filter: Option<JoinFilter>, |
820 | 79 | join_type: JoinType, |
821 | 79 | batch_size: usize, |
822 | 79 | join_metrics: SortMergeJoinMetrics, |
823 | 79 | reservation: MemoryReservation, |
824 | 79 | runtime_env: Arc<RuntimeEnv>, |
825 | 79 | ) -> Result<Self> { |
826 | 79 | let streamed_schema = streamed.schema(); |
827 | 79 | let buffered_schema = buffered.schema(); |
828 | 79 | Ok(Self { |
829 | 79 | state: SMJState::Init, |
830 | 79 | sort_options, |
831 | 79 | null_equals_null, |
832 | 79 | schema, |
833 | 79 | streamed_schema: Arc::clone(&streamed_schema), |
834 | 79 | buffered_schema, |
835 | 79 | streamed, |
836 | 79 | buffered, |
837 | 79 | streamed_batch: StreamedBatch::new_empty(streamed_schema), |
838 | 79 | buffered_data: BufferedData::default(), |
839 | 79 | streamed_joined: false, |
840 | 79 | buffered_joined: false, |
841 | 79 | streamed_state: StreamedState::Init, |
842 | 79 | buffered_state: BufferedState::Init, |
843 | 79 | current_ordering: Ordering::Equal, |
844 | 79 | on_streamed, |
845 | 79 | on_buffered, |
846 | 79 | filter, |
847 | 79 | output_record_batches: vec![], |
848 | 79 | output_size: 0, |
849 | 79 | batch_size, |
850 | 79 | join_type, |
851 | 79 | join_metrics, |
852 | 79 | reservation, |
853 | 79 | runtime_env, |
854 | 79 | }) |
855 | 79 | } |
856 | | |
857 | | /// Poll next streamed row |
858 | 437 | fn poll_streamed_row(&mut self, cx: &mut Context) -> Poll<Option<Result<()>>> { |
859 | | loop { |
860 | 831 | match &self.streamed_state { |
861 | | StreamedState::Init => { |
862 | 437 | if self.streamed_batch.idx + 1 < self.streamed_batch.batch.num_rows() |
863 | | { |
864 | 240 | self.streamed_batch.idx += 1; |
865 | 240 | self.streamed_state = StreamedState::Ready; |
866 | 240 | return Poll::Ready(Some(Ok(()))); |
867 | 197 | } else { |
868 | 197 | self.streamed_state = StreamedState::Polling; |
869 | 197 | } |
870 | | } |
871 | 197 | StreamedState::Polling => match self.streamed.poll_next_unpin(cx)? {
872 | | Poll::Pending => { |
873 | 0 | return Poll::Pending; |
874 | | } |
875 | 67 | Poll::Ready(None) => { |
876 | 67 | self.streamed_state = StreamedState::Exhausted; |
877 | 67 | } |
878 | 130 | Poll::Ready(Some(batch)) => { |
879 | 130 | if batch.num_rows() > 0 { |
880 | 130 | self.freeze_streamed()?;
881 | 130 | self.join_metrics.input_batches.add(1); |
882 | 130 | self.join_metrics.input_rows.add(batch.num_rows()); |
883 | 130 | self.streamed_batch = |
884 | 130 | StreamedBatch::new(batch, &self.on_streamed); |
885 | 130 | self.streamed_state = StreamedState::Ready; |
886 | 0 | } |
887 | | } |
888 | | }, |
889 | | StreamedState::Ready => { |
890 | 130 | return Poll::Ready(Some(Ok(()))); |
891 | | } |
892 | | StreamedState::Exhausted => { |
893 | 67 | return Poll::Ready(None); |
894 | | } |
895 | | } |
896 | | } |
897 | 437 | } |
898 | | |
899 | 118 | fn free_reservation(&mut self, buffered_batch: BufferedBatch) -> Result<()> { |
900 | 118 | // Shrink memory usage for in-memory batches only |
901 | 118 | if buffered_batch.spill_file.is_none() && buffered_batch.batch.is_some() {
902 | 82 | self.reservation |
903 | 82 | .try_shrink(buffered_batch.size_estimation)?;
904 | 36 | } |
905 | | |
906 | 118 | Ok(()) |
907 | 118 | } |
908 | | |
909 | 130 | fn allocate_reservation(&mut self, mut buffered_batch: BufferedBatch) -> Result<()> { |
910 | 130 | match self.reservation.try_grow(buffered_batch.size_estimation) { |
911 | | Ok(_) => { |
912 | 82 | self.join_metrics |
913 | 82 | .peak_mem_used |
914 | 82 | .set_max(self.reservation.size()); |
915 | 82 | Ok(()) |
916 | | } |
917 | 48 | Err(_) if self.runtime_env.disk_manager.tmp_files_enabled() => {
918 | | // spill buffered batch to disk |
919 | 36 | let spill_file = self |
920 | 36 | .runtime_env |
921 | 36 | .disk_manager |
922 | 36 | .create_tmp_file("sort_merge_join_buffered_spill")?0 ; |
923 | | |
924 | 36 | if let Some(batch) = buffered_batch.batch { |
925 | 36 | spill_record_batches( |
926 | 36 | vec![batch], |
927 | 36 | spill_file.path().into(), |
928 | 36 | Arc::clone(&self.buffered_schema), |
929 | 36 | )?;
930 | 36 | buffered_batch.spill_file = Some(spill_file); |
931 | 36 | buffered_batch.batch = None; |
932 | 36 | |
933 | 36 | // update metrics to register spill |
934 | 36 | self.join_metrics.spill_count.add(1); |
935 | 36 | self.join_metrics |
936 | 36 | .spilled_bytes |
937 | 36 | .add(buffered_batch.size_estimation); |
938 | 36 | self.join_metrics.spilled_rows.add(buffered_batch.num_rows); |
939 | 36 | Ok(()) |
940 | | } else { |
941 | 0 | internal_err!("Buffered batch has empty body") |
942 | | } |
943 | | } |
944 | 12 | Err(e) => exec_err!("{}. Disk spilling disabled.", e.message()), |
945 | 12 | }?; |
946 | | |
947 | 118 | self.buffered_data.batches.push_back(buffered_batch); |
948 | 118 | Ok(()) |
949 | 130 | } |
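The method above follows a reserve-or-spill pattern: try to account for the batch in the memory pool, and if the pool refuses, fall back to disk when spilling is enabled. A minimal standalone sketch of just that decision (a hypothetical helper, not part of this file):

fn reserve_or_spill(
    reservation: &mut MemoryReservation,
    can_spill: bool,
    bytes: usize,
) -> Result<bool> {
    match reservation.try_grow(bytes) {
        // accounted for in memory: the caller keeps the batch as-is
        Ok(()) => Ok(false),
        // pool is full but a disk manager is available: the caller spills the batch
        Err(_) if can_spill => Ok(true),
        // no fallback: surface the allocation failure
        Err(e) => exec_err!("{}. Disk spilling disabled.", e.message()),
    }
}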
950 | | |
951 | | /// Poll next buffered batches |
952 | 290 | fn poll_buffered_batches(&mut self, cx: &mut Context) -> Poll<Option<Result<()>>> { |
953 | | loop { |
954 | 837 | match &self.buffered_state { |
955 | | BufferedState::Init => { |
956 | | // pop previous buffered batches |
957 | 408 | while !self.buffered_data.batches.is_empty() { |
958 | 262 | let head_batch = self.buffered_data.head_batch(); |
959 | 262 | // If the head batch is fully processed, dequeue it and produce output of it. |
960 | 262 | if head_batch.range.end == head_batch.num_rows { |
961 | 118 | self.freeze_dequeuing_buffered()?;
962 | 118 | if let Some(buffered_batch) = |
963 | 118 | self.buffered_data.batches.pop_front() |
964 | | { |
965 | 118 | self.free_reservation(buffered_batch)?;
966 | 0 | } |
967 | | } else { |
968 | | // If the head batch is not fully processed, break the loop. |
969 | | // Streamed batch will be joined with the head batch in the next step. |
970 | 144 | break; |
971 | | } |
972 | | } |
973 | 290 | if self.buffered_data.batches.is_empty() { |
974 | 146 | self.buffered_state = BufferedState::PollingFirst; |
975 | 146 | } else { |
976 | 144 | let tail_batch = self.buffered_data.tail_batch_mut(); |
977 | 144 | tail_batch.range.start = tail_batch.range.end; |
978 | 144 | tail_batch.range.end += 1; |
979 | 144 | self.buffered_state = BufferedState::PollingRest; |
980 | 144 | } |
981 | | } |
982 | 146 | BufferedState::PollingFirst => match self.buffered.poll_next_unpin(cx)? {
983 | | Poll::Pending => { |
984 | 0 | return Poll::Pending; |
985 | | } |
986 | | Poll::Ready(None) => { |
987 | 67 | self.buffered_state = BufferedState::Exhausted; |
988 | 67 | return Poll::Ready(None); |
989 | | } |
990 | 79 | Poll::Ready(Some(batch)) => { |
991 | 79 | self.join_metrics.input_batches.add(1); |
992 | 79 | self.join_metrics.input_rows.add(batch.num_rows()); |
993 | 79 | |
994 | 79 | if batch.num_rows() > 0 { |
995 | 79 | let buffered_batch = |
996 | 79 | BufferedBatch::new(batch, 0..1, &self.on_buffered); |
997 | 79 | |
998 | 79 | self.allocate_reservation(buffered_batch)?;
999 | 67 | self.buffered_state = BufferedState::PollingRest; |
1000 | 0 | } |
1001 | | } |
1002 | | }, |
1003 | | BufferedState::PollingRest => { |
1004 | 334 | if self.buffered_data.tail_batch().range.end |
1005 | 334 | < self.buffered_data.tail_batch().num_rows |
1006 | | { |
1007 | 321 | while self.buffered_data.tail_batch().range.end |
1008 | 321 | < self.buffered_data.tail_batch().num_rows |
1009 | | { |
1010 | 249 | if is_join_arrays_equal( |
1011 | 249 | &self.buffered_data.head_batch().join_arrays, |
1012 | 249 | self.buffered_data.head_batch().range.start, |
1013 | 249 | &self.buffered_data.tail_batch().join_arrays, |
1014 | 249 | self.buffered_data.tail_batch().range.end, |
1015 | 249 | )? {
1016 | 105 | self.buffered_data.tail_batch_mut().range.end += 1; |
1017 | 105 | } else { |
1018 | 144 | self.buffered_state = BufferedState::Ready; |
1019 | 144 | return Poll::Ready(Some(Ok(()))); |
1020 | | } |
1021 | | } |
1022 | | } else { |
1023 | 118 | match self.buffered.poll_next_unpin(cx)? {
1024 | | Poll::Pending => { |
1025 | 0 | return Poll::Pending; |
1026 | | } |
1027 | 67 | Poll::Ready(None) => { |
1028 | 67 | self.buffered_state = BufferedState::Ready; |
1029 | 67 | } |
1030 | 51 | Poll::Ready(Some(batch)) => { |
1031 | 51 | // Polling batches coming concurrently as multiple partitions |
1032 | 51 | self.join_metrics.input_batches.add(1); |
1033 | 51 | self.join_metrics.input_rows.add(batch.num_rows()); |
1034 | 51 | if batch.num_rows() > 0 { |
1035 | 51 | let buffered_batch = BufferedBatch::new( |
1036 | 51 | batch, |
1037 | 51 | 0..0, |
1038 | 51 | &self.on_buffered, |
1039 | 51 | ); |
1040 | 51 | self.allocate_reservation(buffered_batch)?;
1041 | 0 | } |
1042 | | } |
1043 | | } |
1044 | | } |
1045 | | } |
1046 | | BufferedState::Ready => { |
1047 | 67 | return Poll::Ready(Some(Ok(()))); |
1048 | | } |
1049 | | BufferedState::Exhausted => { |
1050 | 0 | return Poll::Ready(None); |
1051 | | } |
1052 | | } |
1053 | | } |
1054 | 290 | } |
1055 | | |
1056 | | /// Get comparison result of streamed row and buffered batches |
1057 | 569 | fn compare_streamed_buffered(&self) -> Result<Ordering> { |
1058 | 569 | if self.streamed_state == StreamedState::Exhausted { |
1059 | 81 | return Ok(Ordering::Greater); |
1060 | 488 | } |
1061 | 488 | if !self.buffered_data.has_buffered_rows() { |
1062 | 12 | return Ok(Ordering::Less); |
1063 | 476 | } |
1064 | 476 | |
1065 | 476 | return compare_join_arrays( |
1066 | 476 | &self.streamed_batch.join_arrays, |
1067 | 476 | self.streamed_batch.idx, |
1068 | 476 | &self.buffered_data.head_batch().join_arrays, |
1069 | 476 | self.buffered_data.head_batch().range.start, |
1070 | 476 | &self.sort_options, |
1071 | 476 | self.null_equals_null, |
1072 | 476 | ); |
1073 | 569 | } |
1074 | | |
1075 | | /// Produce join and fill output buffer until reaching target batch size |
1076 | | /// or the join is finished |
1077 | 878 | fn join_partial(&mut self) -> Result<()> { |
1078 | 878 | // Whether to join streamed rows |
1079 | 878 | let mut join_streamed = false; |
1080 | 878 | // Whether to join buffered rows |
1081 | 878 | let mut join_buffered = false; |
1082 | 878 | |
1083 | 878 | // determine whether we need to join streamed/buffered rows |
1084 | 878 | match self.current_ordering { |
1085 | | Ordering::Less => { |
1086 | 20 | if matches!( |
1087 | 84 | self.join_type, |
1088 | | JoinType::Left |
1089 | | | JoinType::Right |
1090 | | | JoinType::RightSemi |
1091 | | | JoinType::Full |
1092 | | | JoinType::LeftAnti |
1093 | 64 | ) { |
1094 | 64 | join_streamed = !self.streamed_joined; |
1095 | 64 | }
1096 | | } |
1097 | | Ordering::Equal => { |
1098 | 581 | if matches!(self.join_type, JoinType::LeftSemi) {
1099 | | // if the join filter is specified then the streamed index needs to be output
1100 | | // only if it has not been emitted before.
1101 | | // `join_filter_matched_idxs` keeps track of whether a streamed index has had a successful
1102 | | // filter match and prevents the same index from going into the output more than once
1103 | 63 | if self.filter.is_some() { |
1104 | 0 | join_streamed = !self |
1105 | 0 | .streamed_batch |
1106 | 0 | .join_filter_matched_idxs |
1107 | 0 | .contains(&(self.streamed_batch.idx as u64)) |
1108 | 0 | && !self.streamed_joined; |
1109 | | // if the join filter is specified there can be references to buffered columns,
1110 | | // so the buffered columns are needed to access them
1111 | 0 | join_buffered = join_streamed; |
1112 | 63 | } else { |
1113 | 63 | join_streamed = !self.streamed_joined; |
1114 | 63 | } |
1115 | 518 | } |
1116 | 106 | if matches!( |
1117 | 581 | self.join_type, |
1118 | | JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full |
1119 | 475 | ) { |
1120 | 475 | join_streamed = true; |
1121 | 475 | join_buffered = true; |
1122 | 475 | };
1123 | | |
1124 | 581 | if matches!(self.join_type, JoinType::LeftAnti) && self.filter.is_some() {
1125 | 0 | join_streamed = !self |
1126 | 0 | .streamed_batch |
1127 | 0 | .join_filter_matched_idxs |
1128 | 0 | .contains(&(self.streamed_batch.idx as u64)) |
1129 | 0 | && !self.streamed_joined; |
1130 | 0 | join_buffered = join_streamed; |
1131 | 581 | } |
1132 | | } |
1133 | | Ordering::Greater => { |
1134 | 213 | if matches!(self.join_type, JoinType::Full) {
1135 | 33 | join_buffered = !self.buffered_joined; |
1136 | 180 | }; |
1137 | | } |
1138 | | } |
1139 | 878 | if !join_streamed && !join_buffered {
1140 | | // no joined data |
1141 | 303 | self.buffered_data.scanning_finish(); |
1142 | 303 | return Ok(()); |
1143 | 575 | } |
1144 | 575 | |
1145 | 575 | if join_buffered { |
1146 | | // joining streamed/nulls and buffered |
1147 | 1.08k | while !self.buffered_data.scanning_finished() |
1148 | 801 | && self.output_size < self.batch_size |
1149 | | { |
1150 | 607 | let scanning_idx = self.buffered_data.scanning_idx(); |
1151 | 607 | if join_streamed { |
1152 | 600 | // Join streamed row and buffered row |
1153 | 600 | self.streamed_batch.append_output_pair( |
1154 | 600 | Some(self.buffered_data.scanning_batch_idx), |
1155 | 600 | Some(scanning_idx), |
1156 | 600 | ); |
1157 | 600 | } else { |
1158 | 7 | // Join nulls and buffered row for FULL join |
1159 | 7 | self.buffered_data |
1160 | 7 | .scanning_batch_mut() |
1161 | 7 | .null_joined |
1162 | 7 | .push(scanning_idx); |
1163 | 7 | } |
1164 | 607 | self.output_size += 1; |
1165 | 607 | self.buffered_data.scanning_advance(); |
1166 | 607 | |
1167 | 607 | if self.buffered_data.scanning_finished() { |
1168 | 209 | self.streamed_joined = join_streamed; |
1169 | 209 | self.buffered_joined = true; |
1170 | 398 | } |
1171 | | } |
1172 | | } else { |
1173 | | // joining streamed and nulls |
1174 | 93 | let scanning_batch_idx = if self.buffered_data.scanning_finished() { |
1175 | 11 | None |
1176 | | } else { |
1177 | 82 | Some(self.buffered_data.scanning_batch_idx) |
1178 | | }; |
1179 | | |
1180 | 93 | self.streamed_batch |
1181 | 93 | .append_output_pair(scanning_batch_idx, None); |
1182 | 93 | self.output_size += 1; |
1183 | 93 | self.buffered_data.scanning_finish(); |
1184 | 93 | self.streamed_joined = true; |
1185 | | } |
1186 | 575 | Ok(()) |
1187 | 878 | } |
1188 | | |
1189 | 417 | fn freeze_all(&mut self) -> Result<()> { |
1190 | 417 | self.freeze_streamed()?;
1191 | 417 | self.freeze_buffered(self.buffered_data.batches.len(), false)?;
1192 | 417 | Ok(()) |
1193 | 417 | } |
1194 | | |
1195 | | // Produces and stages record batches to ensure dequeued buffered batch |
1196 | | // no longer needed: |
1197 | | // 1. freezes all indices joined to streamed side |
1198 | | // 2. freezes NULLs joined to dequeued buffered batch to "release" it |
1199 | 118 | fn freeze_dequeuing_buffered(&mut self) -> Result<()> { |
1200 | 118 | self.freeze_streamed()?;
1201 | | // Only freeze and produce the first batch in buffered_data as the batch is fully processed |
1202 | 118 | self.freeze_buffered(1, true)?;
1203 | 118 | Ok(()) |
1204 | 118 | } |
1205 | | |
1206 | | // Produces and stages record batch from buffered indices with corresponding |
1207 | | // NULLs on streamed side. |
1208 | | // |
1209 | | // Applicable only in case of Full join. |
1210 | | // |
1211 | | // If `output_not_matched_filter` is true, this will also produce record batches |
1212 | | // for buffered rows which are joined with streamed side but don't match join filter. |
1213 | 535 | fn freeze_buffered( |
1214 | 535 | &mut self, |
1215 | 535 | batch_count: usize, |
1216 | 535 | output_not_matched_filter: bool, |
1217 | 535 | ) -> Result<()> { |
1218 | 535 | if !matches!(self.join_type, JoinType::Full) {
1219 | 426 | return Ok(()); |
1220 | 109 | } |
1221 | 213 | for buffered_batch in self.buffered_data.batches.range_mut(..batch_count) {
1222 | 213 | let buffered_indices = UInt64Array::from_iter_values( |
1223 | 213 | buffered_batch.null_joined.iter().map(|&index| index as u64),
1224 | 213 | ); |
1225 | 213 | if let Some(record_batch) = produce_buffered_null_batch(
1226 | 213 | &self.schema, |
1227 | 213 | &self.streamed_schema, |
1228 | 213 | &buffered_indices, |
1229 | 213 | buffered_batch, |
1230 | 213 | )? {
1231 | 7 | self.output_record_batches.push(record_batch); |
1232 | 206 | } |
1233 | 213 | buffered_batch.null_joined.clear(); |
1234 | 213 | |
1235 | 213 | // For buffered row which is joined with streamed side rows but all joined rows |
1236 | 213 | // don't satisfy the join filter |
1237 | 213 | if output_not_matched_filter { |
1238 | 19 | let not_matched_buffered_indices = buffered_batch |
1239 | 19 | .join_filter_failed_map |
1240 | 19 | .iter() |
1241 | 19 | .filter_map(|(idx, failed)| if *failed { Some(*idx) } else { None })
1242 | 19 | .collect::<Vec<_>>(); |
1243 | 19 | |
1244 | 19 | let buffered_indices = UInt64Array::from_iter_values( |
1245 | 19 | not_matched_buffered_indices.iter().copied(), |
1246 | 19 | ); |
1247 | | |
1248 | 19 | if let Some(record_batch) = produce_buffered_null_batch(
1249 | 19 | &self.schema, |
1250 | 19 | &self.streamed_schema, |
1251 | 19 | &buffered_indices, |
1252 | 19 | buffered_batch, |
1253 | 19 | )? {
1254 | 0 | self.output_record_batches.push(record_batch); |
1255 | 19 | } |
1256 | 19 | buffered_batch.join_filter_failed_map.clear(); |
1257 | 194 | } |
1258 | | } |
1259 | 109 | Ok(()) |
1260 | 535 | } |
1261 | | |
1262 | | // Produces and stages record batch for all output indices found |
1263 | | // for current streamed batch and clears staged output indices. |
1264 | 665 | fn freeze_streamed(&mut self) -> Result<()> { |
1265 | 665 | for chunk in self.streamed_batch.output_indices.iter_mut() {
1266 | | // The row indices of joined streamed batch |
1267 | 504 | let streamed_indices = chunk.streamed_indices.finish(); |
1268 | 504 | |
1269 | 504 | if streamed_indices.is_empty() { |
1270 | 0 | continue; |
1271 | 504 | } |
1272 | | |
1273 | 504 | let mut streamed_columns = self |
1274 | 504 | .streamed_batch |
1275 | 504 | .batch |
1276 | 504 | .columns() |
1277 | 504 | .iter() |
1278 | 1.51k | .map(|column| take(column, &streamed_indices, None)) |
1279 | 504 | .collect::<Result<Vec<_>, ArrowError>>()?;
1280 | | |
1281 | | // The row indices of joined buffered batch |
1282 | 504 | let buffered_indices: UInt64Array = chunk.buffered_indices.finish(); |
1283 | 504 | let mut buffered_columns = |
1284 | 504 | if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
1285 | 36 | vec![] |
1286 | 468 | } else if let Some(buffered_idx) = chunk.buffered_batch_idx {
1287 | 459 | get_buffered_columns( |
1288 | 459 | &self.buffered_data, |
1289 | 459 | buffered_idx, |
1290 | 459 | &buffered_indices, |
1291 | 459 | )?
1292 | | } else { |
1293 | | // If the buffered batch is None, this is a null-joined batch.
1294 | | // We need to create null arrays for buffered columns to join with streamed rows. |
1295 | 9 | self.buffered_schema |
1296 | 9 | .fields() |
1297 | 9 | .iter() |
1298 | 27 | .map(|f| new_null_array(f.data_type(), buffered_indices.len())) |
1299 | 9 | .collect::<Vec<_>>() |
1300 | | }; |
1301 | | |
1302 | 504 | let streamed_columns_length = streamed_columns.len(); |
1303 | 504 | let buffered_columns_length = buffered_columns.len(); |
1304 | | |
1305 | | // Prepare the columns we apply join filter on later. |
1306 | | // Only for joined rows between streamed and buffered. |
1307 | 504 | let filter_columns = if chunk.buffered_batch_idx.is_some() { |
1308 | 494 | if matches!(self.join_type, JoinType::Right) {
1309 | 108 | get_filter_column(&self.filter, &buffered_columns, &streamed_columns) |
1310 | 351 | } else if matches!( |
1311 | 386 | self.join_type, |
1312 | | JoinType::LeftSemi | JoinType::LeftAnti |
1313 | | ) { |
1314 | | // unwrap is safe here as we check is_some on top of if statement |
1315 | 35 | let buffered_columns = get_buffered_columns( |
1316 | 35 | &self.buffered_data, |
1317 | 35 | chunk.buffered_batch_idx.unwrap(), |
1318 | 35 | &buffered_indices, |
1319 | 35 | )?;
1320 | | |
1321 | 35 | get_filter_column(&self.filter, &streamed_columns, &buffered_columns) |
1322 | | } else { |
1323 | 351 | get_filter_column(&self.filter, &streamed_columns, &buffered_columns) |
1324 | | } |
1325 | | } else { |
1326 | | // This chunk is totally for null joined rows (outer join), we don't need to apply join filter. |
1327 | | // Any join filter applied only on either streamed or buffered side will be pushed already. |
1328 | 10 | vec![] |
1329 | | }; |
1330 | | |
1331 | 504 | let columns = if matches!(self.join_type, JoinType::Right) {
1332 | 113 | buffered_columns.extend(streamed_columns.clone()); |
1333 | 113 | buffered_columns |
1334 | | } else { |
1335 | 391 | streamed_columns.extend(buffered_columns); |
1336 | 391 | streamed_columns |
1337 | | }; |
1338 | | |
1339 | 504 | let output_batch = |
1340 | 504 | RecordBatch::try_new(Arc::clone(&self.schema), columns.clone())?;
1341 | | |
1342 | | // Apply join filter if any |
1343 | 504 | if !filter_columns.is_empty() { |
1344 | 0 | if let Some(f) = &self.filter { |
1345 | | // Construct batch with only filter columns |
1346 | 0 | let filter_batch = RecordBatch::try_new( |
1347 | 0 | Arc::new(f.schema().clone()), |
1348 | 0 | filter_columns, |
1349 | 0 | )?; |
1350 | | |
1351 | 0 | let filter_result = f |
1352 | 0 | .expression() |
1353 | 0 | .evaluate(&filter_batch)? |
1354 | 0 | .into_array(filter_batch.num_rows())?; |
1355 | | |
1356 | | // The boolean selection mask of the join filter result |
1357 | 0 | let pre_mask = |
1358 | 0 | datafusion_common::cast::as_boolean_array(&filter_result)?; |
1359 | | |
1360 | | // If there are nulls in join filter result, exclude them from selecting |
1361 | | // the rows to output. |
1362 | 0 | let mask = if pre_mask.null_count() > 0 { |
1363 | | compute::prep_null_mask_filter( |
1364 | 0 | datafusion_common::cast::as_boolean_array(&filter_result)?, |
1365 | | ) |
1366 | | } else { |
1367 | 0 | pre_mask.clone() |
1368 | | }; |
1369 | | |
1370 | | // For certain join types, we need to adjust the initial mask to handle the join filter. |
1371 | 0 | let maybe_filtered_join_mask: Option<(BooleanArray, Vec<u64>)> = |
1372 | 0 | get_filtered_join_mask( |
1373 | 0 | self.join_type, |
1374 | 0 | &streamed_indices, |
1375 | 0 | &mask, |
1376 | 0 | &self.streamed_batch.join_filter_matched_idxs, |
1377 | 0 | &self.buffered_data.scanning_offset, |
1378 | 0 | ); |
1379 | | |
1380 | 0 | let mask = |
1381 | 0 | if let Some(ref filtered_join_mask) = maybe_filtered_join_mask { |
1382 | 0 | self.streamed_batch |
1383 | 0 | .join_filter_matched_idxs |
1384 | 0 | .extend(&filtered_join_mask.1); |
1385 | 0 | &filtered_join_mask.0 |
1386 | | } else { |
1387 | 0 | &mask |
1388 | | }; |
1389 | | |
1390 | | // Push the filtered batch which contains rows passing join filter to the output |
1391 | 0 | let filtered_batch = |
1392 | 0 | compute::filter_record_batch(&output_batch, mask)?; |
1393 | 0 | self.output_record_batches.push(filtered_batch); |
1394 | | |
1395 | | // For outer joins, we need to push the null joined rows to the output if |
1396 | | // all joined rows fail the join filter.
1397 | | // I.e., if all rows joined from a streamed row fail the join filter,
1398 | | // we need to join it with nulls as the buffered side.
1399 | 0 | if matches!( |
1400 | 0 | self.join_type, |
1401 | | JoinType::Left | JoinType::Right | JoinType::Full |
1402 | | ) { |
1403 | | // We need to get the mask for row indices whose joined rows all fail
1404 | | // the join filter. I.e., for a row on the streamed side, if all rows joined
1405 | | // between it and all buffered rows fail the join filter, we need to
1406 | | // output it with null columns from the buffered side. For the mask here, it
1407 | | // behaves like LeftAnti join. |
1408 | 0 | let null_mask: BooleanArray = get_filtered_join_mask( |
1409 | 0 | // Set a mask slot as true only if all joined rows of same streamed index |
1410 | 0 | // are failed on the join filter. |
1411 | 0 | // The masking behavior is like LeftAnti join. |
1412 | 0 | JoinType::LeftAnti, |
1413 | 0 | &streamed_indices, |
1414 | 0 | mask, |
1415 | 0 | &self.streamed_batch.join_filter_matched_idxs, |
1416 | 0 | &self.buffered_data.scanning_offset, |
1417 | 0 | ) |
1418 | 0 | .unwrap() |
1419 | 0 | .0; |
1420 | | |
1421 | 0 | let null_joined_batch = |
1422 | 0 | compute::filter_record_batch(&output_batch, &null_mask)?; |
1423 | | |
1424 | 0 | let mut buffered_columns = self |
1425 | 0 | .buffered_schema |
1426 | 0 | .fields() |
1427 | 0 | .iter() |
1428 | 0 | .map(|f| { |
1429 | 0 | new_null_array( |
1430 | 0 | f.data_type(), |
1431 | 0 | null_joined_batch.num_rows(), |
1432 | 0 | ) |
1433 | 0 | }) |
1434 | 0 | .collect::<Vec<_>>(); |
1435 | | |
1436 | 0 | let columns = if matches!(self.join_type, JoinType::Right) { |
1437 | 0 | let streamed_columns = null_joined_batch |
1438 | 0 | .columns() |
1439 | 0 | .iter() |
1440 | 0 | .skip(buffered_columns_length) |
1441 | 0 | .cloned() |
1442 | 0 | .collect::<Vec<_>>(); |
1443 | 0 |
1444 | 0 | buffered_columns.extend(streamed_columns); |
1445 | 0 | buffered_columns |
1446 | | } else { |
1447 | | // Left join or full outer join |
1448 | 0 | let mut streamed_columns = null_joined_batch |
1449 | 0 | .columns() |
1450 | 0 | .iter() |
1451 | 0 | .take(streamed_columns_length) |
1452 | 0 | .cloned() |
1453 | 0 | .collect::<Vec<_>>(); |
1454 | 0 |
1455 | 0 | streamed_columns.extend(buffered_columns); |
1456 | 0 | streamed_columns |
1457 | | }; |
1458 | | |
1459 | | // Push the streamed/buffered batch joined with nulls to the output
1460 | 0 | let null_joined_streamed_batch = RecordBatch::try_new( |
1461 | 0 | Arc::clone(&self.schema), |
1462 | 0 | columns.clone(), |
1463 | 0 | )?; |
1464 | 0 | self.output_record_batches.push(null_joined_streamed_batch); |
1465 | | |
1466 | | // For a full join, we also need to output the null joined rows from the buffered side.
1467 | | // Usually this is done by `freeze_buffered`. However, if a buffered row is joined with
1468 | | // the streamed side, it won't be output by `freeze_buffered`.
1469 | | // We need to check whether a buffered row is joined with the streamed side and output it.
1470 | | // If it is joined with the streamed side but doesn't match the join filter,
1471 | | // we need to output it with nulls on the streamed side.
1472 | 0 | if matches!(self.join_type, JoinType::Full) { |
1473 | 0 | let buffered_batch = &mut self.buffered_data.batches |
1474 | 0 | [chunk.buffered_batch_idx.unwrap()]; |
1475 | | |
1476 | 0 | for i in 0..pre_mask.len() { |
1477 | | // If the buffered row is not joined with the streamed side,
1478 | | // skip it.
1479 | 0 | if buffered_indices.is_null(i) { |
1480 | 0 | continue; |
1481 | 0 | } |
1482 | 0 |
1483 | 0 | let buffered_index = buffered_indices.value(i); |
1484 | 0 |
1485 | 0 | buffered_batch.join_filter_failed_map.insert( |
1486 | 0 | buffered_index, |
1487 | 0 | *buffered_batch |
1488 | 0 | .join_filter_failed_map |
1489 | 0 | .get(&buffered_index) |
1490 | 0 | .unwrap_or(&true) |
1491 | 0 | && !pre_mask.value(i), |
1492 | | ); |
1493 | | } |
1494 | 0 | } |
1495 | 0 | } |
1496 | 0 | } else { |
1497 | 0 | self.output_record_batches.push(output_batch); |
1498 | 0 | } |
1499 | 504 | } else { |
1500 | 504 | self.output_record_batches.push(output_batch); |
1501 | 504 | } |
1502 | | } |
1503 | | |
1504 | 665 | self.streamed_batch.output_indices.clear(); |
1505 | 665 | |
1506 | 665 | Ok(()) |
1507 | 665 | } |
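
A minimal sketch of arrow's `compute::prep_null_mask_filter` as used above: NULL slots in a join filter result are folded to `false`, so those rows are excluded from the filtered output. (Illustrative values only; assumes the arrow crate is in scope.)

    use arrow::array::{Array, BooleanArray};
    use arrow::compute;

    // A join filter result containing a NULL slot.
    let with_nulls = BooleanArray::from(vec![Some(true), None, Some(false)]);
    // NULLs are folded to false, so those rows are not selected.
    let mask = compute::prep_null_mask_filter(&with_nulls);
    assert_eq!(mask.null_count(), 0);
    assert!(!mask.value(1));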
1508 | | |
1509 | 350 | fn output_record_batch_and_reset(&mut self) -> Result<RecordBatch> { |
1510 | 350 | let record_batch = concat_batches(&self.schema, &self.output_record_batches)?0 ; |
1511 | 350 | self.join_metrics.output_batches.add(1); |
1512 | 350 | self.join_metrics.output_rows.add(record_batch.num_rows()); |
1513 | 350 | // If a join filter exists, `self.output_size` is not accurate as we don't know the exact
1514 | 350 | // number of rows in the output record batch. If a streamed row is joined with buffered rows,
1515 | 350 | // once the join filter is applied, the number of output rows may be more than 1.
1516 | 350 | // If `record_batch` is empty, we should reset `self.output_size` to 0. This can happen
1517 | 350 | // when the join filter is applied and all rows are filtered out.
1518 | 350 | if record_batch.num_rows() == 0 || record_batch.num_rows() > self.output_size { |
1519 | 0 | self.output_size = 0; |
1520 | 350 | } else { |
1521 | 350 | self.output_size -= record_batch.num_rows(); |
1522 | 350 | } |
1523 | 350 | self.output_record_batches.clear(); |
1524 | 350 | Ok(record_batch) |
1525 | 350 | } |
1526 | | } |
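
For reference, a small sketch of the `concat_batches` call used by `output_record_batch_and_reset`, with a hypothetical one-column schema (names and values are illustrative):

    use std::sync::Arc;
    use arrow::array::Int32Array;
    use arrow::compute::concat_batches;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;

    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let b1 = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![1, 2]))],
    )
    .unwrap();
    let b2 = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![3]))],
    )
    .unwrap();
    // Merge the accumulated output batches into a single batch.
    let merged = concat_batches(&schema, &[b1, b2]).unwrap();
    assert_eq!(merged.num_rows(), 3);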
1527 | | |
1528 | | /// Gets the arrays that the join filter is applied on.
1529 | 494 | fn get_filter_column( |
1530 | 494 | join_filter: &Option<JoinFilter>, |
1531 | 494 | streamed_columns: &[ArrayRef], |
1532 | 494 | buffered_columns: &[ArrayRef], |
1533 | 494 | ) -> Vec<ArrayRef> { |
1534 | 494 | let mut filter_columns = vec![]; |
1535 | | |
1536 | 494 | if let Some(f) = join_filter {
1537 | 0 | let left_columns = f |
1538 | 0 | .column_indices() |
1539 | 0 | .iter() |
1540 | 0 | .filter(|col_index| col_index.side == JoinSide::Left) |
1541 | 0 | .map(|i| Arc::clone(&streamed_columns[i.index])) |
1542 | 0 | .collect::<Vec<_>>(); |
1543 | 0 |
1544 | 0 | let right_columns = f |
1545 | 0 | .column_indices() |
1546 | 0 | .iter() |
1547 | 0 | .filter(|col_index| col_index.side == JoinSide::Right) |
1548 | 0 | .map(|i| Arc::clone(&buffered_columns[i.index])) |
1549 | 0 | .collect::<Vec<_>>(); |
1550 | 0 |
1551 | 0 | filter_columns.extend(left_columns); |
1552 | 0 | filter_columns.extend(right_columns); |
1553 | 494 | } |
1554 | | |
1555 | 494 | filter_columns |
1556 | 494 | } |
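
A sketch of the gathering convention, using tuples as an illustrative stand-in for the filter's `ColumnIndex` entries: left (streamed) columns come first, then right (buffered) columns, matching the intermediate schema the filter expression is evaluated against.

    use std::sync::Arc;
    use arrow::array::{ArrayRef, Int32Array};
    use datafusion_common::JoinSide;

    // Illustrative stand-ins for the filter's ColumnIndex entries: (side, index).
    let column_indices = [(JoinSide::Left, 0), (JoinSide::Right, 0)];
    let streamed: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![1, 2]))];
    let buffered: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![10, 20]))];

    // Gather left (streamed) columns first, then right (buffered) columns.
    let left_cols = column_indices
        .iter()
        .filter(|(side, _)| *side == JoinSide::Left)
        .map(|(_, i)| Arc::clone(&streamed[*i]));
    let right_cols = column_indices
        .iter()
        .filter(|(side, _)| *side == JoinSide::Right)
        .map(|(_, i)| Arc::clone(&buffered[*i]));
    let filter_columns: Vec<ArrayRef> = left_cols.chain(right_cols).collect();
    assert_eq!(filter_columns.len(), 2);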
1557 | | |
1558 | 232 | fn produce_buffered_null_batch( |
1559 | 232 | schema: &SchemaRef, |
1560 | 232 | streamed_schema: &SchemaRef, |
1561 | 232 | buffered_indices: &PrimitiveArray<UInt64Type>, |
1562 | 232 | buffered_batch: &BufferedBatch, |
1563 | 232 | ) -> Result<Option<RecordBatch>> { |
1564 | 232 | if buffered_indices.is_empty() { |
1565 | 225 | return Ok(None); |
1566 | 7 | } |
1567 | | |
1568 | | // Take buffered (right) columns |
1569 | 7 | let buffered_columns = |
1570 | 7 | get_buffered_columns_from_batch(buffered_batch, buffered_indices)?0 ; |
1571 | | |
1572 | | // Create null streamed (left) columns |
1573 | 7 | let mut streamed_columns = streamed_schema |
1574 | 7 | .fields() |
1575 | 7 | .iter() |
1576 | 21 | .map(|f| new_null_array(f.data_type(), buffered_indices.len())) |
1577 | 7 | .collect::<Vec<_>>(); |
1578 | 7 | |
1579 | 7 | streamed_columns.extend(buffered_columns); |
1580 | 7 | |
1581 | 7 | Ok(Some(RecordBatch::try_new( |
1582 | 7 | Arc::clone(schema), |
1583 | 7 | streamed_columns, |
1584 | 7 | )?0 )) |
1585 | 232 | } |
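
A quick illustration of `new_null_array`, which builds the typed all-null streamed columns above (the Int32 type here is just an example):

    use arrow::array::{new_null_array, Array};
    use arrow::datatypes::DataType;

    // The streamed side is padded with typed all-null columns of matching length.
    let nulls = new_null_array(&DataType::Int32, 3);
    assert_eq!(nulls.len(), 3);
    assert_eq!(nulls.null_count(), 3);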
1586 | | |
1587 | | /// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` |
1588 | | #[inline(always)] |
1589 | 494 | fn get_buffered_columns( |
1590 | 494 | buffered_data: &BufferedData, |
1591 | 494 | buffered_batch_idx: usize, |
1592 | 494 | buffered_indices: &UInt64Array, |
1593 | 494 | ) -> Result<Vec<ArrayRef>> { |
1594 | 494 | get_buffered_columns_from_batch( |
1595 | 494 | &buffered_data.batches[buffered_batch_idx], |
1596 | 494 | buffered_indices, |
1597 | 494 | ) |
1598 | 494 | } |
1599 | | |
1600 | | #[inline(always)] |
1601 | 501 | fn get_buffered_columns_from_batch( |
1602 | 501 | buffered_batch: &BufferedBatch, |
1603 | 501 | buffered_indices: &UInt64Array, |
1604 | 501 | ) -> Result<Vec<ArrayRef>> { |
1605 | 501 | match (&buffered_batch.spill_file, &buffered_batch.batch) { |
1606 | | // In memory batch |
1607 | 347 | (None, Some(batch)) => Ok(batch |
1608 | 347 | .columns() |
1609 | 347 | .iter() |
1610 | 1.04k | .map(|column| take(column, &buffered_indices, None)) |
1611 | 347 | .collect::<Result<Vec<_>, ArrowError>>() |
1612 | 347 | .map_err(Into::<DataFusionError>::into)?0 ), |
1613 | | // The batch was spilled to disk (less likely)
1614 | 154 | (Some(spill_file), None) => { |
1615 | 154 | let mut buffered_cols: Vec<ArrayRef> = |
1616 | 154 | Vec::with_capacity(buffered_indices.len()); |
1617 | | |
1618 | 154 | let file = BufReader::new(File::open(spill_file.path())?0 ); |
1619 | 154 | let reader = FileReader::try_new(file, None)?0 ; |
1620 | | |
1621 | 308 | for batch in reader {
1622 | 462 | batch?.columns().iter().for_each(|column| {
1623 | 462 | buffered_cols.extend(take(column, &buffered_indices, None)) |
1624 | 462 | }); |
1625 | 154 | } |
1626 | | |
1627 | 154 | Ok(buffered_cols) |
1628 | | } |
1629 | | // Invalid combination |
1630 | 0 | (spill, batch) => internal_err!("Unexpected buffered batch spill status. Spill exists: {}. In-memory exists: {}", spill.is_some(), batch.is_some()), |
1631 | | } |
1632 | 501 | } |
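
A minimal sketch of the arrow `take` kernel used in both branches above; indices may repeat, which is how one buffered row can appear in several joined output rows (illustrative values):

    use arrow::array::{Int32Array, UInt64Array};
    use arrow::compute::take;

    let column = Int32Array::from(vec![10, 20, 30]);
    let indices = UInt64Array::from(vec![2, 0, 0]);
    // Gathers rows by index, preserving the index order.
    let taken = take(&column, &indices, None).unwrap();
    let taken = taken.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(taken.value(0), 30);
    assert_eq!(taken.value(2), 10);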
1633 | | |
1634 | | /// Calculate the join filter bit mask considering join type specifics
1635 | | /// `streamed_indices` - array of streamed datasource JOINED row indices
1636 | | /// `mask` - array of booleans representing the computed join filter expression eval result:
1637 | | ///     true = the row index matches the join filter
1638 | | ///     false = the row index doesn't match the join filter
1639 | | /// `streamed_indices` has the same length as `mask`
1640 | | /// `matched_indices` - set of streamed indices that already have a join filter match
1641 | | /// `scanning_buffered_offset` - current buffered offset across batches
1642 | | ///
1643 | | /// This returns a tuple of:
1644 | | /// - the corrected mask with respect to the join type
1645 | | /// - indices of rows in the streamed batch that have a join filter match
1646 | 13 | fn get_filtered_join_mask( |
1647 | 13 | join_type: JoinType, |
1648 | 13 | streamed_indices: &UInt64Array, |
1649 | 13 | mask: &BooleanArray, |
1650 | 13 | matched_indices: &HashSet<u64>, |
1651 | 13 | scanning_buffered_offset: &usize, |
1652 | 13 | ) -> Option<(BooleanArray, Vec<u64>)> { |
1653 | 13 | let mut seen_as_true: bool = false; |
1654 | 13 | let streamed_indices_length = streamed_indices.len(); |
1655 | 13 | let mut corrected_mask: BooleanBuilder = |
1656 | 13 | BooleanBuilder::with_capacity(streamed_indices_length); |
1657 | 13 | |
1658 | 13 | let mut filter_matched_indices: Vec<u64> = vec![]; |
1659 | 13 | |
1660 | 13 | #[allow(clippy::needless_range_loop)] |
1661 | 13 | match join_type { |
1662 | | // for LeftSemi join the filter mask should be calculated in its own way:
1663 | | // if we find at least one matching row for a specific streaming index
1664 | | // we don't need to check any others for the same index
1665 | | JoinType::LeftSemi => { |
1666 | | // have we seen a filter match for a streaming index before |
1667 | 28 | for i in 0..streamed_indices_length {
1668 | | // LeftSemi respects only the first true value for a specific streaming index;
1669 | | // any other true values for the same index must be false
1670 | 28 | let streamed_idx = streamed_indices.value(i); |
1671 | 28 | if mask.value(i) |
1672 | 14 | && !seen_as_true |
1673 | 10 | && !matched_indices.contains(&streamed_idx) |
1674 | 9 | { |
1675 | 9 | seen_as_true = true; |
1676 | 9 | corrected_mask.append_value(true); |
1677 | 9 | filter_matched_indices.push(streamed_idx); |
1678 | 19 | } else { |
1679 | 19 | corrected_mask.append_value(false); |
1680 | 19 | } |
1681 | | |
1682 | | // if we switched to the next streaming index (e.g. from 0 to 1, or from 1 to 2), we reset the seen_as_true flag
1683 | 28 | if i < streamed_indices_length - 1 |
1684 | 21 | && streamed_idx != streamed_indices.value(i + 1) |
1685 | 7 | { |
1686 | 7 | seen_as_true = false; |
1687 | 21 | } |
1688 | | } |
1689 | 7 | Some((corrected_mask.finish(), filter_matched_indices)) |
1690 | | } |
1691 | | // LeftAnti semantics: return true only if the join filter is false for every joined row of the streaming index.
1692 | | // `filter_matched_indices` needs to be set once per streaming index |
1693 | | // to prevent duplicates in the output |
1694 | | JoinType::LeftAnti => { |
1695 | | // have we seen a filter match for a streaming index before |
1696 | 22 | for i in 0..streamed_indices_length {
1697 | 22 | let streamed_idx = streamed_indices.value(i); |
1698 | 22 | if mask.value(i) |
1699 | 12 | && !seen_as_true |
1700 | 8 | && !matched_indices.contains(&streamed_idx) |
1701 | 8 | { |
1702 | 8 | seen_as_true = true; |
1703 | 8 | filter_matched_indices.push(streamed_idx); |
1704 | 14 | } |
1705 | | |
1706 | | // Reset the `seen_as_true` flag and calculate the mask for the current streaming index:
1707 | | // - if within the batch we switched to the next streaming index (e.g. from 0 to 1, or from 1 to 2)
1708 | | // - if we are at the end of all buffered batches for the given streaming index (the 0 index comes last)
1709 | 22 | if (i < streamed_indices_length - 1 |
1710 | 16 | && streamed_idx != streamed_indices.value(i + 1)) |
1711 | 16 | || (i == streamed_indices_length - 1 |
1712 | 6 | && *scanning_buffered_offset == 0) |
1713 | | { |
1714 | 12 | corrected_mask.append_value( |
1715 | 12 | !matched_indices.contains(&streamed_idx) && !seen_as_true, |
1716 | | ); |
1717 | 12 | seen_as_true = false; |
1718 | 10 | } else { |
1719 | 10 | corrected_mask.append_value(false); |
1720 | 10 | } |
1721 | | } |
1722 | | |
1723 | 6 | Some((corrected_mask.finish(), filter_matched_indices)) |
1724 | | } |
1725 | 0 | _ => None, |
1726 | | } |
1727 | 13 | } |
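
An illustrative call for the LeftSemi case, in the spirit of the tests below: streamed index 0 passes the filter twice, but only the first true survives in the corrected mask (values are made up):

    use arrow::array::{BooleanArray, UInt64Array};
    use datafusion_common::JoinType;
    use hashbrown::HashSet;

    let streamed_indices = UInt64Array::from(vec![0, 0, 1, 1]);
    let mask = BooleanArray::from(vec![true, true, false, true]);
    let (corrected, matched) = get_filtered_join_mask(
        JoinType::LeftSemi,
        &streamed_indices,
        &mask,
        &HashSet::new(),
        &0,
    )
    .unwrap();
    // Only the first true per streamed index survives.
    assert_eq!(corrected, BooleanArray::from(vec![true, false, false, true]));
    assert_eq!(matched, vec![0, 1]);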
1728 | | |
1729 | | /// Buffered data contains all buffered batches with one unique join key |
1730 | | #[derive(Debug, Default)] |
1731 | | struct BufferedData { |
1732 | | /// Buffered batches with the same key |
1733 | | pub batches: VecDeque<BufferedBatch>, |
1734 | | /// current scanning batch index used in join_partial() |
1735 | | pub scanning_batch_idx: usize, |
1736 | | /// current scanning offset used in join_partial() |
1737 | | pub scanning_offset: usize, |
1738 | | } |
1739 | | |
1740 | | impl BufferedData { |
1741 | 1.71k | pub fn head_batch(&self) -> &BufferedBatch { |
1742 | 1.71k | self.batches.front().unwrap() |
1743 | 1.71k | } |
1744 | | |
1745 | 1.80k | pub fn tail_batch(&self) -> &BufferedBatch { |
1746 | 1.80k | self.batches.back().unwrap() |
1747 | 1.80k | } |
1748 | | |
1749 | 249 | pub fn tail_batch_mut(&mut self) -> &mut BufferedBatch { |
1750 | 249 | self.batches.back_mut().unwrap() |
1751 | 249 | } |
1752 | | |
1753 | 488 | pub fn has_buffered_rows(&self) -> bool { |
1754 | 488 | self.batches.iter().any(|batch| !batch.range.is_empty())
1755 | 488 | } |
1756 | | |
1757 | 569 | pub fn scanning_reset(&mut self) { |
1758 | 569 | self.scanning_batch_idx = 0; |
1759 | 569 | self.scanning_offset = 0; |
1760 | 569 | } |
1761 | | |
1762 | 607 | pub fn scanning_advance(&mut self) { |
1763 | 607 | self.scanning_offset += 1; |
1764 | 1.00k | while !self.scanning_finished() && self.scanning_batch_finished() {
1765 | 399 | self.scanning_batch_idx += 1; |
1766 | 399 | self.scanning_offset = 0; |
1767 | 399 | } |
1768 | 607 | } |
1769 | | |
1770 | 1.40k | pub fn scanning_batch(&self) -> &BufferedBatch { |
1771 | 1.40k | &self.batches[self.scanning_batch_idx] |
1772 | 1.40k | } |
1773 | | |
1774 | 7 | pub fn scanning_batch_mut(&mut self) -> &mut BufferedBatch { |
1775 | 7 | &mut self.batches[self.scanning_batch_idx] |
1776 | 7 | } |
1777 | | |
1778 | 607 | pub fn scanning_idx(&self) -> usize { |
1779 | 607 | self.scanning_batch().range.start + self.scanning_offset |
1780 | 607 | } |
1781 | | |
1782 | 797 | pub fn scanning_batch_finished(&self) -> bool { |
1783 | 797 | self.scanning_offset == self.scanning_batch().range.len() |
1784 | 797 | } |
1785 | | |
1786 | 3.36k | pub fn scanning_finished(&self) -> bool { |
1787 | 3.36k | self.scanning_batch_idx == self.batches.len() |
1788 | 3.36k | } |
1789 | | |
1790 | 396 | pub fn scanning_finish(&mut self) { |
1791 | 396 | self.scanning_batch_idx = self.batches.len(); |
1792 | 396 | self.scanning_offset = 0; |
1793 | 396 | } |
1794 | | } |
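
The scanning methods implement a two-level cursor: a batch index plus an offset inside that batch's row range. A generic sketch of the same advance logic over plain ranges; `Cursor` is an illustrative stand-in, not the actual `BufferedData`/`BufferedBatch` types:

    use std::ops::Range;

    // Illustrative stand-in for BufferedData's cursor over per-batch row ranges.
    struct Cursor {
        ranges: Vec<Range<usize>>,
        batch_idx: usize,
        offset: usize,
    }

    impl Cursor {
        fn finished(&self) -> bool {
            self.batch_idx == self.ranges.len()
        }
        fn advance(&mut self) {
            self.offset += 1;
            // Skip exhausted batches, mirroring scanning_advance() above.
            while !self.finished() && self.offset == self.ranges[self.batch_idx].len() {
                self.batch_idx += 1;
                self.offset = 0;
            }
        }
    }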
1795 | | |
1796 | | /// Get join array refs of a given batch and join columns
1797 | 260 | fn join_arrays(batch: &RecordBatch, on_column: &[PhysicalExprRef]) -> Vec<ArrayRef> { |
1798 | 260 | on_column |
1799 | 260 | .iter() |
1800 | 270 | .map(|c| { |
1801 | 270 | let num_rows = batch.num_rows(); |
1802 | 270 | let c = c.evaluate(batch).unwrap(); |
1803 | 270 | c.into_array(num_rows).unwrap() |
1804 | 270 | }) |
1805 | 260 | .collect() |
1806 | 260 | } |
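
A small usage sketch, assuming a one-column batch and a `Column` physical expression (the same pattern the tests below use to build join keys):

    use std::sync::Arc;
    use arrow::array::{Array, Int32Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use crate::expressions::Column;

    let schema = Arc::new(Schema::new(vec![Field::new("b1", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![4, 5, 6]))],
    )
    .unwrap();
    // Evaluate the join key expression ("b1" at column index 0) against the batch.
    let arrays = join_arrays(&batch, &[Arc::new(Column::new("b1", 0)) as _]);
    assert_eq!(arrays.len(), 1);
    assert_eq!(arrays[0].len(), 3);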
1807 | | |
1808 | | /// Get comparison result of two rows of join arrays |
1809 | 476 | fn compare_join_arrays( |
1810 | 476 | left_arrays: &[ArrayRef], |
1811 | 476 | left: usize, |
1812 | 476 | right_arrays: &[ArrayRef], |
1813 | 476 | right: usize, |
1814 | 476 | sort_options: &[SortOptions], |
1815 | 476 | null_equals_null: bool, |
1816 | 476 | ) -> Result<Ordering> { |
1817 | 476 | let mut res = Ordering::Equal; |
1818 | 494 | for ((left_array, right_array), sort_options) in |
1819 | 476 | left_arrays.iter().zip(right_arrays).zip(sort_options) |
1820 | | { |
1821 | | macro_rules! compare_value { |
1822 | | ($T:ty) => {{ |
1823 | | let left_array = left_array.as_any().downcast_ref::<$T>().unwrap(); |
1824 | | let right_array = right_array.as_any().downcast_ref::<$T>().unwrap(); |
1825 | | match (left_array.is_null(left), right_array.is_null(right)) { |
1826 | | (false, false) => { |
1827 | | let left_value = &left_array.value(left); |
1828 | | let right_value = &right_array.value(right); |
1829 | | res = left_value.partial_cmp(right_value).unwrap(); |
1830 | | if sort_options.descending { |
1831 | | res = res.reverse(); |
1832 | | } |
1833 | | } |
1834 | | (true, false) => { |
1835 | | res = if sort_options.nulls_first { |
1836 | | Ordering::Less |
1837 | | } else { |
1838 | | Ordering::Greater |
1839 | | }; |
1840 | | } |
1841 | | (false, true) => { |
1842 | | res = if sort_options.nulls_first { |
1843 | | Ordering::Greater |
1844 | | } else { |
1845 | | Ordering::Less |
1846 | | }; |
1847 | | } |
1848 | | _ => { |
1849 | | res = if null_equals_null { |
1850 | | Ordering::Equal |
1851 | | } else { |
1852 | | Ordering::Less |
1853 | | }; |
1854 | | } |
1855 | | } |
1856 | | }}; |
1857 | | } |
1858 | | |
1859 | 494 | match left_array.data_type() { |
1860 | 0 | DataType::Null => {} |
1861 | 0 | DataType::Boolean => compare_value!(BooleanArray), |
1862 | 0 | DataType::Int8 => compare_value!(Int8Array), |
1863 | 0 | DataType::Int16 => compare_value!(Int16Array), |
1864 | 485 | DataType::Int32 => compare_value!(Int32Array), |
1865 | 0 | DataType::Int64 => compare_value!(Int64Array), |
1866 | 0 | DataType::UInt8 => compare_value!(UInt8Array), |
1867 | 0 | DataType::UInt16 => compare_value!(UInt16Array), |
1868 | 0 | DataType::UInt32 => compare_value!(UInt32Array), |
1869 | 0 | DataType::UInt64 => compare_value!(UInt64Array), |
1870 | 0 | DataType::Float32 => compare_value!(Float32Array), |
1871 | 0 | DataType::Float64 => compare_value!(Float64Array), |
1872 | 0 | DataType::Utf8 => compare_value!(StringArray), |
1873 | 0 | DataType::LargeUtf8 => compare_value!(LargeStringArray), |
1874 | 0 | DataType::Decimal128(..) => compare_value!(Decimal128Array), |
1875 | 0 | DataType::Timestamp(time_unit, None) => match time_unit { |
1876 | 0 | TimeUnit::Second => compare_value!(TimestampSecondArray), |
1877 | 0 | TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray), |
1878 | 0 | TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray), |
1879 | 0 | TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray), |
1880 | | }, |
1881 | 4 | DataType::Date32 => compare_value!(Date32Array), |
1882 | 5 | DataType::Date64 => compare_value!(Date64Array), |
1883 | 0 | dt => { |
1884 | 0 | return not_impl_err!( |
1885 | 0 | "Unsupported data type in sort merge join comparator: {}", |
1886 | 0 | dt |
1887 | 0 | ); |
1888 | | } |
1889 | | } |
1890 | 494 | if !res.is_eq() { |
1891 | 188 | break; |
1892 | 306 | } |
1893 | | } |
1894 | 476 | Ok(res) |
1895 | 476 | } |
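
For example, comparing row 1 of a left key array against row 0 of a right key array under the default ascending sort options (illustrative values):

    use std::cmp::Ordering;
    use std::sync::Arc;
    use arrow::array::{ArrayRef, Int32Array};
    use arrow::compute::SortOptions;

    let left: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![1, 5]))];
    let right: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![3]))];
    // Row 1 of the left key (5) vs row 0 of the right key (3), ascending order.
    let ord = compare_join_arrays(&left, 1, &right, 0, &[SortOptions::default()], false).unwrap();
    assert_eq!(ord, Ordering::Greater);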
1896 | | |
1897 | | /// A faster version of compare_join_arrays() that only outputs whether
1898 | | /// the given two rows are equal
1899 | 249 | fn is_join_arrays_equal( |
1900 | 249 | left_arrays: &[ArrayRef], |
1901 | 249 | left: usize, |
1902 | 249 | right_arrays: &[ArrayRef], |
1903 | 249 | right: usize, |
1904 | 249 | ) -> Result<bool> { |
1905 | 249 | let mut is_equal = true; |
1906 | 252 | for (left_array, right_array) in left_arrays.iter().zip(right_arrays)249 { |
1907 | | macro_rules! compare_value { |
1908 | | ($T:ty) => {{ |
1909 | | match (left_array.is_null(left), right_array.is_null(right)) { |
1910 | | (false, false) => { |
1911 | | let left_array = |
1912 | | left_array.as_any().downcast_ref::<$T>().unwrap(); |
1913 | | let right_array = |
1914 | | right_array.as_any().downcast_ref::<$T>().unwrap(); |
1915 | | if left_array.value(left) != right_array.value(right) { |
1916 | | is_equal = false; |
1917 | | } |
1918 | | } |
1919 | | (true, false) => is_equal = false, |
1920 | | (false, true) => is_equal = false, |
1921 | | _ => {} |
1922 | | } |
1923 | | }}; |
1924 | | } |
1925 | | |
1926 | 252 | match left_array.data_type() { |
1927 | 0 | DataType::Null => {} |
1928 | 0 | DataType::Boolean => compare_value!(BooleanArray), |
1929 | 0 | DataType::Int8 => compare_value!(Int8Array), |
1930 | 0 | DataType::Int16 => compare_value!(Int16Array), |
1931 | 248 | DataType::Int32 => compare_value!(Int32Array), |
1932 | 0 | DataType::Int64 => compare_value!(Int64Array), |
1933 | 0 | DataType::UInt8 => compare_value!(UInt8Array), |
1934 | 0 | DataType::UInt16 => compare_value!(UInt16Array), |
1935 | 0 | DataType::UInt32 => compare_value!(UInt32Array), |
1936 | 0 | DataType::UInt64 => compare_value!(UInt64Array), |
1937 | 0 | DataType::Float32 => compare_value!(Float32Array), |
1938 | 0 | DataType::Float64 => compare_value!(Float64Array), |
1939 | 0 | DataType::Utf8 => compare_value!(StringArray), |
1940 | 0 | DataType::LargeUtf8 => compare_value!(LargeStringArray), |
1941 | 0 | DataType::Decimal128(..) => compare_value!(Decimal128Array), |
1942 | 0 | DataType::Timestamp(time_unit, None) => match time_unit { |
1943 | 0 | TimeUnit::Second => compare_value!(TimestampSecondArray), |
1944 | 0 | TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray), |
1945 | 0 | TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray), |
1946 | 0 | TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray), |
1947 | | }, |
1948 | 2 | DataType::Date32 => compare_value!(Date32Array), |
1949 | 2 | DataType::Date64 => compare_value!(Date64Array), |
1950 | 0 | dt => { |
1951 | 0 | return not_impl_err!( |
1952 | 0 | "Unsupported data type in sort merge join comparator: {}", |
1953 | 0 | dt |
1954 | 0 | ); |
1955 | | } |
1956 | | } |
1957 | 252 | if !is_equal { |
1958 | 144 | return Ok(false); |
1959 | 108 | } |
1960 | | } |
1961 | 105 | Ok(true) |
1962 | 249 | } |
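
And a corresponding sketch for the equality-only variant (illustrative values):

    use std::sync::Arc;
    use arrow::array::{ArrayRef, Int32Array};

    let left: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![7, 8]))];
    let right: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![8]))];
    // Compares row 1 of the left keys against row 0 of the right keys.
    assert!(is_join_arrays_equal(&left, 1, &right, 0).unwrap());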
1963 | | |
1964 | | #[cfg(test)] |
1965 | | mod tests { |
1966 | | use std::sync::Arc; |
1967 | | |
1968 | | use arrow::array::{Date32Array, Date64Array, Int32Array}; |
1969 | | use arrow::compute::SortOptions; |
1970 | | use arrow::datatypes::{DataType, Field, Schema}; |
1971 | | use arrow::record_batch::RecordBatch; |
1972 | | use arrow_array::{BooleanArray, UInt64Array}; |
1973 | | use hashbrown::HashSet; |
1974 | | |
1975 | | use datafusion_common::JoinType::{LeftAnti, LeftSemi}; |
1976 | | use datafusion_common::{ |
1977 | | assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result, |
1978 | | }; |
1979 | | use datafusion_execution::config::SessionConfig; |
1980 | | use datafusion_execution::disk_manager::DiskManagerConfig; |
1981 | | use datafusion_execution::runtime_env::RuntimeEnvBuilder; |
1982 | | use datafusion_execution::TaskContext; |
1983 | | |
1984 | | use crate::expressions::Column; |
1985 | | use crate::joins::sort_merge_join::get_filtered_join_mask; |
1986 | | use crate::joins::utils::JoinOn; |
1987 | | use crate::joins::SortMergeJoinExec; |
1988 | | use crate::memory::MemoryExec; |
1989 | | use crate::test::build_table_i32; |
1990 | | use crate::{common, ExecutionPlan}; |
1991 | | |
1992 | 28 | fn build_table( |
1993 | 28 | a: (&str, &Vec<i32>), |
1994 | 28 | b: (&str, &Vec<i32>), |
1995 | 28 | c: (&str, &Vec<i32>), |
1996 | 28 | ) -> Arc<dyn ExecutionPlan> { |
1997 | 28 | let batch = build_table_i32(a, b, c); |
1998 | 28 | let schema = batch.schema(); |
1999 | 28 | Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) |
2000 | 28 | } |
2001 | | |
2002 | 10 | fn build_table_from_batches(batches: Vec<RecordBatch>) -> Arc<dyn ExecutionPlan> { |
2003 | 10 | let schema = batches.first().unwrap().schema(); |
2004 | 10 | Arc::new(MemoryExec::try_new(&[batches], schema, None).unwrap()) |
2005 | 10 | } |
2006 | | |
2007 | 2 | fn build_date_table( |
2008 | 2 | a: (&str, &Vec<i32>), |
2009 | 2 | b: (&str, &Vec<i32>), |
2010 | 2 | c: (&str, &Vec<i32>), |
2011 | 2 | ) -> Arc<dyn ExecutionPlan> { |
2012 | 2 | let schema = Schema::new(vec![ |
2013 | 2 | Field::new(a.0, DataType::Date32, false), |
2014 | 2 | Field::new(b.0, DataType::Date32, false), |
2015 | 2 | Field::new(c.0, DataType::Date32, false), |
2016 | 2 | ]); |
2017 | 2 | |
2018 | 2 | let batch = RecordBatch::try_new( |
2019 | 2 | Arc::new(schema), |
2020 | 2 | vec![ |
2021 | 2 | Arc::new(Date32Array::from(a.1.clone())), |
2022 | 2 | Arc::new(Date32Array::from(b.1.clone())), |
2023 | 2 | Arc::new(Date32Array::from(c.1.clone())), |
2024 | 2 | ], |
2025 | 2 | ) |
2026 | 2 | .unwrap(); |
2027 | 2 | |
2028 | 2 | let schema = batch.schema(); |
2029 | 2 | Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) |
2030 | 2 | } |
2031 | | |
2032 | 2 | fn build_date64_table( |
2033 | 2 | a: (&str, &Vec<i64>), |
2034 | 2 | b: (&str, &Vec<i64>), |
2035 | 2 | c: (&str, &Vec<i64>), |
2036 | 2 | ) -> Arc<dyn ExecutionPlan> { |
2037 | 2 | let schema = Schema::new(vec![ |
2038 | 2 | Field::new(a.0, DataType::Date64, false), |
2039 | 2 | Field::new(b.0, DataType::Date64, false), |
2040 | 2 | Field::new(c.0, DataType::Date64, false), |
2041 | 2 | ]); |
2042 | 2 | |
2043 | 2 | let batch = RecordBatch::try_new( |
2044 | 2 | Arc::new(schema), |
2045 | 2 | vec![ |
2046 | 2 | Arc::new(Date64Array::from(a.1.clone())), |
2047 | 2 | Arc::new(Date64Array::from(b.1.clone())), |
2048 | 2 | Arc::new(Date64Array::from(c.1.clone())), |
2049 | 2 | ], |
2050 | 2 | ) |
2051 | 2 | .unwrap(); |
2052 | 2 | |
2053 | 2 | let schema = batch.schema(); |
2054 | 2 | Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) |
2055 | 2 | } |
2056 | | |
2057 | | /// returns a table with 3 columns of i32 in memory |
2058 | 4 | pub fn build_table_i32_nullable( |
2059 | 4 | a: (&str, &Vec<Option<i32>>), |
2060 | 4 | b: (&str, &Vec<Option<i32>>), |
2061 | 4 | c: (&str, &Vec<Option<i32>>), |
2062 | 4 | ) -> Arc<dyn ExecutionPlan> { |
2063 | 4 | let schema = Arc::new(Schema::new(vec![ |
2064 | 4 | Field::new(a.0, DataType::Int32, true), |
2065 | 4 | Field::new(b.0, DataType::Int32, true), |
2066 | 4 | Field::new(c.0, DataType::Int32, true), |
2067 | 4 | ])); |
2068 | 4 | let batch = RecordBatch::try_new( |
2069 | 4 | Arc::clone(&schema), |
2070 | 4 | vec![ |
2071 | 4 | Arc::new(Int32Array::from(a.1.clone())), |
2072 | 4 | Arc::new(Int32Array::from(b.1.clone())), |
2073 | 4 | Arc::new(Int32Array::from(c.1.clone())), |
2074 | 4 | ], |
2075 | 4 | ) |
2076 | 4 | .unwrap(); |
2077 | 4 | Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) |
2078 | 4 | } |
2079 | | |
2080 | 1 | fn join( |
2081 | 1 | left: Arc<dyn ExecutionPlan>, |
2082 | 1 | right: Arc<dyn ExecutionPlan>, |
2083 | 1 | on: JoinOn, |
2084 | 1 | join_type: JoinType, |
2085 | 1 | ) -> Result<SortMergeJoinExec> { |
2086 | 1 | let sort_options = vec![SortOptions::default(); on.len()]; |
2087 | 1 | SortMergeJoinExec::try_new(left, right, on, None, join_type, sort_options, false) |
2088 | 1 | } |
2089 | | |
2090 | 78 | fn join_with_options( |
2091 | 78 | left: Arc<dyn ExecutionPlan>, |
2092 | 78 | right: Arc<dyn ExecutionPlan>, |
2093 | 78 | on: JoinOn, |
2094 | 78 | join_type: JoinType, |
2095 | 78 | sort_options: Vec<SortOptions>, |
2096 | 78 | null_equals_null: bool, |
2097 | 78 | ) -> Result<SortMergeJoinExec> { |
2098 | 78 | SortMergeJoinExec::try_new( |
2099 | 78 | left, |
2100 | 78 | right, |
2101 | 78 | on, |
2102 | 78 | None, |
2103 | 78 | join_type, |
2104 | 78 | sort_options, |
2105 | 78 | null_equals_null, |
2106 | 78 | ) |
2107 | 78 | } |
2108 | | |
2109 | 17 | async fn join_collect( |
2110 | 17 | left: Arc<dyn ExecutionPlan>, |
2111 | 17 | right: Arc<dyn ExecutionPlan>, |
2112 | 17 | on: JoinOn, |
2113 | 17 | join_type: JoinType, |
2114 | 17 | ) -> Result<(Vec<String>, Vec<RecordBatch>)> { |
2115 | 17 | let sort_options = vec![SortOptions::default(); on.len()]; |
2116 | 17 | join_collect_with_options(left, right, on, join_type, sort_options, false).await0 |
2117 | 17 | } |
2118 | | |
2119 | 18 | async fn join_collect_with_options( |
2120 | 18 | left: Arc<dyn ExecutionPlan>, |
2121 | 18 | right: Arc<dyn ExecutionPlan>, |
2122 | 18 | on: JoinOn, |
2123 | 18 | join_type: JoinType, |
2124 | 18 | sort_options: Vec<SortOptions>, |
2125 | 18 | null_equals_null: bool, |
2126 | 18 | ) -> Result<(Vec<String>, Vec<RecordBatch>)> { |
2127 | 18 | let task_ctx = Arc::new(TaskContext::default()); |
2128 | 18 | let join = join_with_options( |
2129 | 18 | left, |
2130 | 18 | right, |
2131 | 18 | on, |
2132 | 18 | join_type, |
2133 | 18 | sort_options, |
2134 | 18 | null_equals_null, |
2135 | 18 | )?0 ; |
2136 | 18 | let columns = columns(&join.schema()); |
2137 | | |
2138 | 18 | let stream = join.execute(0, task_ctx)?0 ; |
2139 | 18 | let batches = common::collect(stream).await0 ?0 ; |
2140 | 18 | Ok((columns, batches)) |
2141 | 18 | } |
2142 | | |
2143 | 1 | async fn join_collect_batch_size_equals_two( |
2144 | 1 | left: Arc<dyn ExecutionPlan>, |
2145 | 1 | right: Arc<dyn ExecutionPlan>, |
2146 | 1 | on: JoinOn, |
2147 | 1 | join_type: JoinType, |
2148 | 1 | ) -> Result<(Vec<String>, Vec<RecordBatch>)> { |
2149 | 1 | let task_ctx = TaskContext::default() |
2150 | 1 | .with_session_config(SessionConfig::new().with_batch_size(2)); |
2151 | 1 | let task_ctx = Arc::new(task_ctx); |
2152 | 1 | let join = join(left, right, on, join_type)?0 ; |
2153 | 1 | let columns = columns(&join.schema()); |
2154 | | |
2155 | 1 | let stream = join.execute(0, task_ctx)?0 ; |
2156 | 1 | let batches = common::collect(stream).await0 ?0 ; |
2157 | 1 | Ok((columns, batches)) |
2158 | 1 | } |
2159 | | |
2160 | | #[tokio::test] |
2161 | 1 | async fn join_inner_one() -> Result<()> { |
2162 | 1 | let left = build_table( |
2163 | 1 | ("a1", &vec![1, 2, 3]), |
2164 | 1 | ("b1", &vec![4, 5, 5]), // this has a repetition |
2165 | 1 | ("c1", &vec![7, 8, 9]), |
2166 | 1 | ); |
2167 | 1 | let right = build_table( |
2168 | 1 | ("a2", &vec![10, 20, 30]), |
2169 | 1 | ("b1", &vec![4, 5, 6]), |
2170 | 1 | ("c2", &vec![70, 80, 90]), |
2171 | 1 | ); |
2172 | 1 | |
2173 | 1 | let on = vec![( |
2174 | 1 | Arc::new(Column::new_with_schema("b1", &left.schema())?0 ) as _, |
2175 | 1 | Arc::new(Column::new_with_schema("b1", &right.schema())?0 ) as _, |
2176 | 1 | )]; |
2177 | 1 | |
2178 | 1 | let (_, batches) = join_collect(left, right, on, JoinType::Inner).await0 ?0 ; |
2179 | 1 | |
2180 | 1 | let expected = [ |
2181 | 1 | "+----+----+----+----+----+----+", |
2182 | 1 | "| a1 | b1 | c1 | a2 | b1 | c2 |", |
2183 | 1 | "+----+----+----+----+----+----+", |
2184 | 1 | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
2185 | 1 | "| 2 | 5 | 8 | 20 | 5 | 80 |", |
2186 | 1 | "| 3 | 5 | 9 | 20 | 5 | 80 |", |
2187 | 1 | "+----+----+----+----+----+----+", |
2188 | 1 | ]; |
2189 | 1 | // The output order is important as SMJ preserves sortedness |
2190 | 1 | assert_batches_eq!(expected, &batches); |
2191 | 1 | Ok(()) |
2192 | 1 | } |
2193 | | |
2194 | | #[tokio::test] |
2195 | 1 | async fn join_inner_two() -> Result<()> { |
2196 | 1 | let left = build_table( |
2197 | 1 | ("a1", &vec![1, 2, 2]), |
2198 | 1 | ("b2", &vec![1, 2, 2]), |
2199 | 1 | ("c1", &vec![7, 8, 9]), |
2200 | 1 | ); |
2201 | 1 | let right = build_table( |
2202 | 1 | ("a1", &vec![1, 2, 3]), |
2203 | 1 | ("b2", &vec![1, 2, 2]), |
2204 | 1 | ("c2", &vec![70, 80, 90]), |
2205 | 1 | ); |
2206 | 1 | let on = vec![ |
2207 | 1 | ( |
2208 | 1 | Arc::new(Column::new_with_schema("a1", &left.schema())?0 ) as _, |
2209 | 1 | Arc::new(Column::new_with_schema("a1", &right.schema())?0 ) as _, |
2210 | 1 | ), |
2211 | 1 | ( |
2212 | 1 | Arc::new(Column::new_with_schema("b2", &left.schema())?0 ) as _, |
2213 | 1 | Arc::new(Column::new_with_schema("b2", &right.schema())?0 ) as _, |
2214 | 1 | ), |
2215 | 1 | ]; |
2216 | 1 | |
2217 | 1 | let (_columns, batches) = join_collect(left, right, on, JoinType::Inner).await0 ?0 ; |
2218 | 1 | let expected = [ |
2219 | 1 | "+----+----+----+----+----+----+", |
2220 | 1 | "| a1 | b2 | c1 | a1 | b2 | c2 |", |
2221 | 1 | "+----+----+----+----+----+----+", |
2222 | 1 | "| 1 | 1 | 7 | 1 | 1 | 70 |", |
2223 | 1 | "| 2 | 2 | 8 | 2 | 2 | 80 |", |
2224 | 1 | "| 2 | 2 | 9 | 2 | 2 | 80 |", |
2225 | 1 | "+----+----+----+----+----+----+", |
2226 | 1 | ]; |
2227 | 1 | // The output order is important as SMJ preserves sortedness |
2228 | 1 | assert_batches_eq!(expected, &batches); |
2229 | 1 | Ok(()) |
2230 | 1 | } |
2231 | | |
2232 | | #[tokio::test] |
2233 | 1 | async fn join_inner_two_two() -> Result<()> { |
2234 | 1 | let left = build_table( |
2235 | 1 | ("a1", &vec![1, 1, 2]), |
2236 | 1 | ("b2", &vec![1, 1, 2]), |
2237 | 1 | ("c1", &vec![7, 8, 9]), |
2238 | 1 | ); |
2239 | 1 | let right = build_table( |
2240 | 1 | ("a1", &vec![1, 1, 3]), |
2241 | 1 | ("b2", &vec![1, 1, 2]), |
2242 | 1 | ("c2", &vec![70, 80, 90]), |
2243 | 1 | ); |
2244 | 1 | let on = vec![ |
2245 | 1 | ( |
2246 | 1 | Arc::new(Column::new_with_schema("a1", &left.schema())?0 ) as _, |
2247 | 1 | Arc::new(Column::new_with_schema("a1", &right.schema())?0 ) as _, |
2248 | 1 | ), |
2249 | 1 | ( |
2250 | 1 | Arc::new(Column::new_with_schema("b2", &left.schema())?0 ) as _, |
2251 | 1 | Arc::new(Column::new_with_schema("b2", &right.schema())?0 ) as _, |
2252 | 1 | ), |
2253 | 1 | ]; |
2254 | 1 | |
2255 | 1 | let (_columns, batches) = join_collect(left, right, on, JoinType::Inner).await0 ?0 ; |
2256 | 1 | let expected = [ |
2257 | 1 | "+----+----+----+----+----+----+", |
2258 | 1 | "| a1 | b2 | c1 | a1 | b2 | c2 |", |
2259 | 1 | "+----+----+----+----+----+----+", |
2260 | 1 | "| 1 | 1 | 7 | 1 | 1 | 70 |", |
2261 | 1 | "| 1 | 1 | 7 | 1 | 1 | 80 |", |
2262 | 1 | "| 1 | 1 | 8 | 1 | 1 | 70 |", |
2263 | 1 | "| 1 | 1 | 8 | 1 | 1 | 80 |", |
2264 | 1 | "+----+----+----+----+----+----+", |
2265 | 1 | ]; |
2266 | 1 | // The output order is important as SMJ preserves sortedness |
2267 | 1 | assert_batches_eq!(expected, &batches); |
2268 | 1 | Ok(()) |
2269 | 1 | } |
2270 | | |
2271 | | #[tokio::test] |
2272 | 1 | async fn join_inner_with_nulls() -> Result<()> { |
2273 | 1 | let left = build_table_i32_nullable( |
2274 | 1 | ("a1", &vec![Some(1), Some(1), Some(2), Some(2)]), |
2275 | 1 | ("b2", &vec![None, Some(1), Some(2), Some(2)]), // null in key field |
2276 | 1 | ("c1", &vec![Some(1), None, Some(8), Some(9)]), // null in non-key field |
2277 | 1 | ); |
2278 | 1 | let right = build_table_i32_nullable( |
2279 | 1 | ("a1", &vec![Some(1), Some(1), Some(2), Some(3)]), |
2280 | 1 | ("b2", &vec![None, Some(1), Some(2), Some(2)]), |
2281 | 1 | ("c2", &vec![Some(10), Some(70), Some(80), Some(90)]), |
2282 | 1 | ); |
2283 | 1 | let on = vec![ |
2284 | 1 | ( |
2285 | 1 | Arc::new(Column::new_with_schema("a1", &left.schema())?0 ) as _, |
2286 | 1 | Arc::new(Column::new_with_schema("a1", &right.schema())?0 ) as _, |
2287 | 1 | ), |
2288 | 1 | ( |
2289 | 1 | Arc::new(Column::new_with_schema("b2", &left.schema())?0 ) as _, |
2290 | 1 | Arc::new(Column::new_with_schema("b2", &right.schema())?0 ) as _, |
2291 | 1 | ), |
2292 | 1 | ]; |
2293 | 1 | |
2294 | 1 | let (_, batches) = join_collect(left, right, on, JoinType::Inner).await0 ?0 ; |
2295 | 1 | let expected = [ |
2296 | 1 | "+----+----+----+----+----+----+", |
2297 | 1 | "| a1 | b2 | c1 | a1 | b2 | c2 |", |
2298 | 1 | "+----+----+----+----+----+----+", |
2299 | 1 | "| 1 | 1 | | 1 | 1 | 70 |", |
2300 | 1 | "| 2 | 2 | 8 | 2 | 2 | 80 |", |
2301 | 1 | "| 2 | 2 | 9 | 2 | 2 | 80 |", |
2302 | 1 | "+----+----+----+----+----+----+", |
2303 | 1 | ]; |
2304 | 1 | // The output order is important as SMJ preserves sortedness |
2305 | 1 | assert_batches_eq!(expected, &batches); |
2306 | 1 | Ok(()) |
2307 | 1 | } |
2308 | | |
2309 | | #[tokio::test] |
2310 | 1 | async fn join_inner_with_nulls_with_options() -> Result<()> { |
2311 | 1 | let left = build_table_i32_nullable( |
2312 | 1 | ("a1", &vec![Some(2), Some(2), Some(1), Some(1)]), |
2313 | 1 | ("b2", &vec![Some(2), Some(2), Some(1), None]), // null in key field |
2314 | 1 | ("c1", &vec![Some(9), Some(8), None, Some(1)]), // null in non-key field |
2315 | 1 | ); |
2316 | 1 | let right = build_table_i32_nullable( |
2317 | 1 | ("a1", &vec![Some(3), Some(2), Some(1), Some(1)]), |
2318 | 1 | ("b2", &vec![Some(2), Some(2), Some(1), None]), |
2319 | 1 | ("c2", &vec![Some(90), Some(80), Some(70), Some(10)]), |
2320 | 1 | ); |
2321 | 1 | let on = vec![ |
2322 | 1 | ( |
2323 | 1 | Arc::new(Column::new_with_schema("a1", &left.schema())?0 ) as _, |
2324 | 1 | Arc::new(Column::new_with_schema("a1", &right.schema())?0 ) as _, |
2325 | 1 | ), |
2326 | 1 | ( |
2327 | 1 | Arc::new(Column::new_with_schema("b2", &left.schema())?0 ) as _, |
2328 | 1 | Arc::new(Column::new_with_schema("b2", &right.schema())?0 ) as _, |
2329 | 1 | ), |
2330 | 1 | ]; |
2331 | 1 | let (_, batches) = join_collect_with_options( |
2332 | 1 | left, |
2333 | 1 | right, |
2334 | 1 | on, |
2335 | 1 | JoinType::Inner, |
2336 | 1 | vec![ |
2337 | 1 | SortOptions { |
2338 | 1 | descending: true, |
2339 | 1 | nulls_first: false, |
2340 | 1 | }; |
2341 | 1 | 2 |
2342 | 1 | ], |
2343 | 1 | true, |
2344 | 1 | ) |
2345 | 1 | .await0 ?0 ; |
2346 | 1 | let expected = [ |
2347 | 1 | "+----+----+----+----+----+----+", |
2348 | 1 | "| a1 | b2 | c1 | a1 | b2 | c2 |", |
2349 | 1 | "+----+----+----+----+----+----+", |
2350 | 1 | "| 2 | 2 | 9 | 2 | 2 | 80 |", |
2351 | 1 | "| 2 | 2 | 8 | 2 | 2 | 80 |", |
2352 | 1 | "| 1 | 1 | | 1 | 1 | 70 |", |
2353 | 1 | "| 1 | | 1 | 1 | | 10 |", |
2354 | 1 | "+----+----+----+----+----+----+", |
2355 | 1 | ]; |
2356 | 1 | // The output order is important as SMJ preserves sortedness |
2357 | 1 | assert_batches_eq!(expected, &batches); |
2358 | 1 | Ok(()) |
2359 | 1 | } |
2360 | | |
2361 | | #[tokio::test] |
2362 | 1 | async fn join_inner_output_two_batches() -> Result<()> { |
2363 | 1 | let left = build_table( |
2364 | 1 | ("a1", &vec![1, 2, 2]), |
2365 | 1 | ("b2", &vec![1, 2, 2]), |
2366 | 1 | ("c1", &vec![7, 8, 9]), |
2367 | 1 | ); |
2368 | 1 | let right = build_table( |
2369 | 1 | ("a1", &vec![1, 2, 3]), |
2370 | 1 | ("b2", &vec![1, 2, 2]), |
2371 | 1 | ("c2", &vec![70, 80, 90]), |
2372 | 1 | ); |
2373 | 1 | let on = vec![ |
2374 | 1 | ( |
2375 | 1 | Arc::new(Column::new_with_schema("a1", &left.schema())?0 ) as _, |
2376 | 1 | Arc::new(Column::new_with_schema("a1", &right.schema())?0 ) as _, |
2377 | 1 | ), |
2378 | 1 | ( |
2379 | 1 | Arc::new(Column::new_with_schema("b2", &left.schema())?0 ) as _, |
2380 | 1 | Arc::new(Column::new_with_schema("b2", &right.schema())?0 ) as _, |
2381 | 1 | ), |
2382 | 1 | ]; |
2383 | 1 | |
2384 | 1 | let (_, batches) = |
2385 | 1 | join_collect_batch_size_equals_two(left, right, on, JoinType::Inner).await0 ?0 ; |
2386 | 1 | let expected = [ |
2387 | 1 | "+----+----+----+----+----+----+", |
2388 | 1 | "| a1 | b2 | c1 | a1 | b2 | c2 |", |
2389 | 1 | "+----+----+----+----+----+----+", |
2390 | 1 | "| 1 | 1 | 7 | 1 | 1 | 70 |", |
2391 | 1 | "| 2 | 2 | 8 | 2 | 2 | 80 |", |
2392 | 1 | "| 2 | 2 | 9 | 2 | 2 | 80 |", |
2393 | 1 | "+----+----+----+----+----+----+", |
2394 | 1 | ]; |
2395 | 1 | assert_eq!(batches.len(), 2); |
2396 | 1 | assert_eq!(batches[0].num_rows(), 2); |
2397 | 1 | assert_eq!(batches[1].num_rows(), 1); |
2398 | 1 | // The output order is important as SMJ preserves sortedness |
2399 | 1 | assert_batches_eq!(expected, &batches); |
2400 | 1 | Ok(()) |
2401 | 1 | } |
2402 | | |
2403 | | #[tokio::test] |
2404 | 1 | async fn join_left_one() -> Result<()> { |
2405 | 1 | let left = build_table( |
2406 | 1 | ("a1", &vec![1, 2, 3]), |
2407 | 1 | ("b1", &vec![4, 5, 7]), // 7 does not exist on the right |
2408 | 1 | ("c1", &vec![7, 8, 9]), |
2409 | 1 | ); |
2410 | 1 | let right = build_table( |
2411 | 1 | ("a2", &vec![10, 20, 30]), |
2412 | 1 | ("b1", &vec![4, 5, 6]), |
2413 | 1 | ("c2", &vec![70, 80, 90]), |
2414 | 1 | ); |
2415 | 1 | let on = vec![( |
2416 | 1 | Arc::new(Column::new_with_schema("b1", &left.schema())?0 ) as _, |
2417 | 1 | Arc::new(Column::new_with_schema("b1", &right.schema())?0 ) as _, |
2418 | 1 | )]; |
2419 | 1 | |
2420 | 1 | let (_, batches) = join_collect(left, right, on, JoinType::Left).await0 ?0 ; |
2421 | 1 | let expected = [ |
2422 | 1 | "+----+----+----+----+----+----+", |
2423 | 1 | "| a1 | b1 | c1 | a2 | b1 | c2 |", |
2424 | 1 | "+----+----+----+----+----+----+", |
2425 | 1 | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
2426 | 1 | "| 2 | 5 | 8 | 20 | 5 | 80 |", |
2427 | 1 | "| 3 | 7 | 9 | | | |", |
2428 | 1 | "+----+----+----+----+----+----+", |
2429 | 1 | ]; |
2430 | 1 | // The output order is important as SMJ preserves sortedness |
2431 | 1 | assert_batches_eq!(expected, &batches); |
2432 | 1 | Ok(()) |
2433 | 1 | } |
2434 | | |
2435 | | #[tokio::test] |
2436 | 1 | async fn join_right_one() -> Result<()> { |
2437 | 1 | let left = build_table( |
2438 | 1 | ("a1", &vec![1, 2, 3]), |
2439 | 1 | ("b1", &vec![4, 5, 7]), |
2440 | 1 | ("c1", &vec![7, 8, 9]), |
2441 | 1 | ); |
2442 | 1 | let right = build_table( |
2443 | 1 | ("a2", &vec![10, 20, 30]), |
2444 | 1 | ("b1", &vec![4, 5, 6]), // 6 does not exist on the left |
2445 | 1 | ("c2", &vec![70, 80, 90]), |
2446 | 1 | ); |
2447 | 1 | let on = vec![( |
2448 | 1 | Arc::new(Column::new_with_schema("b1", &left.schema())?0 ) as _, |
2449 | 1 | Arc::new(Column::new_with_schema("b1", &right.schema())?0 ) as _, |
2450 | 1 | )]; |
2451 | 1 | |
2452 | 1 | let (_, batches) = join_collect(left, right, on, JoinType::Right).await0 ?0 ; |
2453 | 1 | let expected = [ |
2454 | 1 | "+----+----+----+----+----+----+", |
2455 | 1 | "| a1 | b1 | c1 | a2 | b1 | c2 |", |
2456 | 1 | "+----+----+----+----+----+----+", |
2457 | 1 | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
2458 | 1 | "| 2 | 5 | 8 | 20 | 5 | 80 |", |
2459 | 1 | "| | | | 30 | 6 | 90 |", |
2460 | 1 | "+----+----+----+----+----+----+", |
2461 | 1 | ]; |
2462 | 1 | // The output order is important as SMJ preserves sortedness |
2463 | 1 | assert_batches_eq!(expected, &batches); |
2464 | 1 | Ok(()) |
2465 | 1 | } |
2466 | | |
2467 | | #[tokio::test] |
2468 | 1 | async fn join_full_one() -> Result<()> { |
2469 | 1 | let left = build_table( |
2470 | 1 | ("a1", &vec![1, 2, 3]), |
2471 | 1 | ("b1", &vec![4, 5, 7]), // 7 does not exist on the right |
2472 | 1 | ("c1", &vec![7, 8, 9]), |
2473 | 1 | ); |
2474 | 1 | let right = build_table( |
2475 | 1 | ("a2", &vec![10, 20, 30]), |
2476 | 1 | ("b2", &vec![4, 5, 6]), |
2477 | 1 | ("c2", &vec![70, 80, 90]), |
2478 | 1 | ); |
2479 | 1 | let on = vec![( |
2480 | 1 | Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) as _, |
2481 | 1 | Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, |
2482 | 1 | )]; |
2483 | 1 | |
2484 | 1 | let (_, batches) = join_collect(left, right, on, JoinType::Full).await0 ?0 ; |
2485 | 1 | let expected = [ |
2486 | 1 | "+----+----+----+----+----+----+", |
2487 | 1 | "| a1 | b1 | c1 | a2 | b2 | c2 |", |
2488 | 1 | "+----+----+----+----+----+----+", |
2489 | 1 | "| | | | 30 | 6 | 90 |", |
2490 | 1 | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
2491 | 1 | "| 2 | 5 | 8 | 20 | 5 | 80 |", |
2492 | 1 | "| 3 | 7 | 9 | | | |", |
2493 | 1 | "+----+----+----+----+----+----+", |
2494 | 1 | ]; |
2495 | 1 | assert_batches_sorted_eq!(expected, &batches); |
2496 | 1 | Ok(()) |
2497 | 1 | } |
2498 | | |
2499 | | #[tokio::test] |
2500 | 1 | async fn join_anti() -> Result<()> { |
2501 | 1 | let left = build_table( |
2502 | 1 | ("a1", &vec![1, 2, 2, 3, 5]), |
2503 | 1 | ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right |
2504 | 1 | ("c1", &vec![7, 8, 8, 9, 11]), |
2505 | 1 | ); |
2506 | 1 | let right = build_table( |
2507 | 1 | ("a2", &vec![10, 20, 30]), |
2508 | 1 | ("b1", &vec![4, 5, 6]), |
2509 | 1 | ("c2", &vec![70, 80, 90]), |
2510 | 1 | ); |
2511 | 1 | let on = vec![( |
2512 | 1 | Arc::new(Column::new_with_schema("b1", &left.schema())?0 ) as _, |
2513 | 1 | Arc::new(Column::new_with_schema("b1", &right.schema())?0 ) as _, |
2514 | 1 | )]; |
2515 | 1 | |
2516 | 1 | let (_, batches) = join_collect(left, right, on, JoinType::LeftAnti).await0 ?0 ; |
2517 | 1 | let expected = [ |
2518 | 1 | "+----+----+----+", |
2519 | 1 | "| a1 | b1 | c1 |", |
2520 | 1 | "+----+----+----+", |
2521 | 1 | "| 3 | 7 | 9 |", |
2522 | 1 | "| 5 | 7 | 11 |", |
2523 | 1 | "+----+----+----+", |
2524 | 1 | ]; |
2525 | 1 | // The output order is important as SMJ preserves sortedness |
2526 | 1 | assert_batches_eq!(expected, &batches); |
2527 | 1 | Ok(()) |
2528 | 1 | } |
2529 | | |
2530 | | #[tokio::test] |
2531 | 1 | async fn join_semi() -> Result<()> { |
2532 | 1 | let left = build_table( |
2533 | 1 | ("a1", &vec![1, 2, 2, 3]), |
2534 | 1 | ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right |
2535 | 1 | ("c1", &vec![7, 8, 8, 9]), |
2536 | 1 | ); |
2537 | 1 | let right = build_table( |
2538 | 1 | ("a2", &vec![10, 20, 30]), |
2539 | 1 | ("b1", &vec![4, 5, 6]),
2540 | 1 | ("c2", &vec![70, 80, 90]), |
2541 | 1 | ); |
2542 | 1 | let on = vec![( |
2543 | 1 | Arc::new(Column::new_with_schema("b1", &left.schema())?0 ) as _, |
2544 | 1 | Arc::new(Column::new_with_schema("b1", &right.schema())?0 ) as _, |
2545 | 1 | )]; |
2546 | 1 | |
2547 | 1 | let (_, batches) = join_collect(left, right, on, JoinType::LeftSemi).await0 ?0 ; |
2548 | 1 | let expected = [ |
2549 | 1 | "+----+----+----+", |
2550 | 1 | "| a1 | b1 | c1 |", |
2551 | 1 | "+----+----+----+", |
2552 | 1 | "| 1 | 4 | 7 |", |
2553 | 1 | "| 2 | 5 | 8 |", |
2554 | 1 | "| 2 | 5 | 8 |", |
2555 | 1 | "+----+----+----+", |
2556 | 1 | ]; |
2557 | 1 | // The output order is important as SMJ preserves sortedness |
2558 | 1 | assert_batches_eq!(expected, &batches); |
2559 | 1 | Ok(()) |
2560 | 1 | } |
2561 | | |
2562 | | #[tokio::test] |
2563 | 1 | async fn join_with_duplicated_column_names() -> Result<()> { |
2564 | 1 | let left = build_table( |
2565 | 1 | ("a", &vec![1, 2, 3]), |
2566 | 1 | ("b", &vec![4, 5, 7]), |
2567 | 1 | ("c", &vec![7, 8, 9]), |
2568 | 1 | ); |
2569 | 1 | let right = build_table( |
2570 | 1 | ("a", &vec![10, 20, 30]), |
2571 | 1 | ("b", &vec![1, 2, 7]), |
2572 | 1 | ("c", &vec![70, 80, 90]), |
2573 | 1 | ); |
2574 | 1 | let on = vec![( |
2575 | 1 | // join on a=b so there are duplicate column names on unjoined columns |
2576 | 1 | Arc::new(Column::new_with_schema("a", &left.schema())?0 ) as _, |
2577 | 1 | Arc::new(Column::new_with_schema("b", &right.schema())?0 ) as _, |
2578 | 1 | )]; |
2579 | 1 | |
2580 | 1 | let (_, batches) = join_collect(left, right, on, JoinType::Inner).await0 ?0 ; |
2581 | 1 | let expected = [ |
2582 | 1 | "+---+---+---+----+---+----+", |
2583 | 1 | "| a | b | c | a | b | c |", |
2584 | 1 | "+---+---+---+----+---+----+", |
2585 | 1 | "| 1 | 4 | 7 | 10 | 1 | 70 |", |
2586 | 1 | "| 2 | 5 | 8 | 20 | 2 | 80 |", |
2587 | 1 | "+---+---+---+----+---+----+", |
2588 | 1 | ]; |
2589 | 1 | // The output order is important as SMJ preserves sortedness |
2590 | 1 | assert_batches_eq!(expected, &batches); |
2591 | 1 | Ok(()) |
2592 | 1 | } |
2593 | | |
2594 | | #[tokio::test] |
2595 | 1 | async fn join_date32() -> Result<()> { |
2596 | 1 | let left = build_date_table( |
2597 | 1 | ("a1", &vec![1, 2, 3]), |
2598 | 1 | ("b1", &vec![19107, 19108, 19108]), // this has a repetition |
2599 | 1 | ("c1", &vec![7, 8, 9]), |
2600 | 1 | ); |
2601 | 1 | let right = build_date_table( |
2602 | 1 | ("a2", &vec![10, 20, 30]), |
2603 | 1 | ("b1", &vec![19107, 19108, 19109]), |
2604 | 1 | ("c2", &vec![70, 80, 90]), |
2605 | 1 | ); |
2606 | 1 | |
2607 | 1 | let on = vec![( |
2608 | 1 | Arc::new(Column::new_with_schema("b1", &left.schema())?0 ) as _, |
2609 | 1 | Arc::new(Column::new_with_schema("b1", &right.schema())?0 ) as _, |
2610 | 1 | )]; |
2611 | 1 | |
2612 | 1 | let (_, batches) = join_collect(left, right, on, JoinType::Inner).await0 ?0 ; |
2613 | 1 | |
2614 | 1 | let expected = ["+------------+------------+------------+------------+------------+------------+", |
2615 | 1 | "| a1 | b1 | c1 | a2 | b1 | c2 |", |
2616 | 1 | "+------------+------------+------------+------------+------------+------------+", |
2617 | 1 | "| 1970-01-02 | 2022-04-25 | 1970-01-08 | 1970-01-11 | 2022-04-25 | 1970-03-12 |", |
2618 | 1 | "| 1970-01-03 | 2022-04-26 | 1970-01-09 | 1970-01-21 | 2022-04-26 | 1970-03-22 |", |
2619 | 1 | "| 1970-01-04 | 2022-04-26 | 1970-01-10 | 1970-01-21 | 2022-04-26 | 1970-03-22 |", |
2620 | 1 | "+------------+------------+------------+------------+------------+------------+"]; |
2621 | 1 | // The output order is important as SMJ preserves sortedness |
2622 | 1 | assert_batches_eq!(expected, &batches); |
2623 | 1 | Ok(()) |
2624 | 1 | } |
2625 | | |
2626 | | #[tokio::test] |
2627 | 1 | async fn join_date64() -> Result<()> { |
2628 | 1 | let left = build_date64_table( |
2629 | 1 | ("a1", &vec![1, 2, 3]), |
2630 | 1 | ("b1", &vec![1650703441000, 1650903441000, 1650903441000]), // this has a repetition |
2631 | 1 | ("c1", &vec![7, 8, 9]), |
2632 | 1 | ); |
2633 | 1 | let right = build_date64_table( |
2634 | 1 | ("a2", &vec![10, 20, 30]), |
2635 | 1 | ("b1", &vec![1650703441000, 1650503441000, 1650903441000]), |
2636 | 1 | ("c2", &vec![70, 80, 90]), |
2637 | 1 | ); |
2638 | 1 | |
2639 | 1 | let on = vec![( |
2640 | 1 | Arc::new(Column::new_with_schema("b1", &left.schema())?0 ) as _, |
2641 | 1 | Arc::new(Column::new_with_schema("b1", &right.schema())?0 ) as _, |
2642 | 1 | )]; |
2643 | 1 | |
2644 | 1 | let (_, batches) = join_collect(left, right, on, JoinType::Inner).await0 ?0 ; |
2645 | 1 | |
2646 | 1 | let expected = ["+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", |
2647 | 1 | "| a1 | b1 | c1 | a2 | b1 | c2 |", |
2648 | 1 | "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", |
2649 | 1 | "| 1970-01-01T00:00:00.001 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.007 | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.070 |", |
2650 | 1 | "| 1970-01-01T00:00:00.002 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |", |
2651 | 1 | "| 1970-01-01T00:00:00.003 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |", |
2652 | 1 | "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+"]; |
2653 | 1 | // The output order is important as SMJ preserves sortedness |
2654 | 1 | assert_batches_eq!(expected, &batches); |
2655 | 1 | Ok(()) |
2656 | 1 | } |
2657 | | |
2658 | | #[tokio::test] |
2659 | 1 | async fn join_left_sort_order() -> Result<()> { |
2660 | 1 | let left = build_table( |
2661 | 1 | ("a1", &vec![0, 1, 2, 3, 4, 5]), |
2662 | 1 | ("b1", &vec![3, 4, 5, 6, 6, 7]), |
2663 | 1 | ("c1", &vec![4, 5, 6, 7, 8, 9]), |
2664 | 1 | ); |
2665 | 1 | let right = build_table( |
2666 | 1 | ("a2", &vec![0, 10, 20, 30, 40]), |
2667 | 1 | ("b2", &vec![2, 4, 6, 6, 8]), |
2668 | 1 | ("c2", &vec![50, 60, 70, 80, 90]), |
2669 | 1 | ); |
2670 | 1 | let on = vec![( |
2671 | 1 | Arc::new(Column::new_with_schema("b1", &left.schema())?0 ) as _, |
2672 | 1 | Arc::new(Column::new_with_schema("b2", &right.schema())?0 ) as _, |
2673 | 1 | )]; |
2674 | 1 | |
2675 | 1 | let (_, batches) = join_collect(left, right, on, JoinType::Left).await0 ?0 ; |
2676 | 1 | let expected = [ |
2677 | 1 | "+----+----+----+----+----+----+", |
2678 | 1 | "| a1 | b1 | c1 | a2 | b2 | c2 |", |
2679 | 1 | "+----+----+----+----+----+----+", |
2680 | 1 | "| 0 | 3 | 4 | | | |", |
2681 | 1 | "| 1 | 4 | 5 | 10 | 4 | 60 |", |
2682 | 1 | "| 2 | 5 | 6 | | | |", |
2683 | 1 | "| 3 | 6 | 7 | 20 | 6 | 70 |", |
2684 | 1 | "| 3 | 6 | 7 | 30 | 6 | 80 |", |
2685 | 1 | "| 4 | 6 | 8 | 20 | 6 | 70 |", |
2686 | 1 | "| 4 | 6 | 8 | 30 | 6 | 80 |", |
2687 | 1 | "| 5 | 7 | 9 | | | |", |
2688 | 1 | "+----+----+----+----+----+----+", |
2689 | 1 | ]; |
2690 | 1 | assert_batches_eq!(expected, &batches); |
2691 | 1 | Ok(()) |
2692 | 1 | } |
2693 | | |
2694 | | #[tokio::test] |
2695 | 1 | async fn join_right_sort_order() -> Result<()> { |
2696 | 1 | let left = build_table( |
2697 | 1 | ("a1", &vec![0, 1, 2, 3]), |
2698 | 1 | ("b1", &vec![3, 4, 5, 7]), |
2699 | 1 | ("c1", &vec![6, 7, 8, 9]), |
2700 | 1 | ); |
2701 | 1 | let right = build_table( |
2702 | 1 | ("a2", &vec![0, 10, 20, 30]), |
2703 | 1 | ("b2", &vec![2, 4, 5, 6]), |
2704 | 1 | ("c2", &vec![60, 70, 80, 90]), |
2705 | 1 | ); |
2706 | 1 | let on = vec![( |
2707 | 1 | Arc::new(Column::new_with_schema("b1", &left.schema())?0 ) as _, |
2708 | 1 | Arc::new(Column::new_with_schema("b2", &right.schema())?0 ) as _, |
2709 | 1 | )]; |
2710 | 1 | |
2711 | 1 | let (_, batches) = join_collect(left, right, on, JoinType::Right).await0 ?0 ; |
2712 | 1 | let expected = [ |
2713 | 1 | "+----+----+----+----+----+----+", |
2714 | 1 | "| a1 | b1 | c1 | a2 | b2 | c2 |", |
2715 | 1 | "+----+----+----+----+----+----+", |
2716 | 1 | "| | | | 0 | 2 | 60 |", |
2717 | 1 | "| 1 | 4 | 7 | 10 | 4 | 70 |", |
2718 | 1 | "| 2 | 5 | 8 | 20 | 5 | 80 |", |
2719 | 1 | "| | | | 30 | 6 | 90 |", |
2720 | 1 | "+----+----+----+----+----+----+", |
2721 | 1 | ]; |
2722 | 1 | assert_batches_eq!(expected, &batches); |
2723 | 1 | Ok(()) |
2724 | 1 | } |
2725 | | |
2726 | | #[tokio::test] |
2727 | 1 | async fn join_left_multiple_batches() -> Result<()> { |
2728 | 1 | let left_batch_1 = build_table_i32( |
2729 | 1 | ("a1", &vec![0, 1, 2]), |
2730 | 1 | ("b1", &vec![3, 4, 5]), |
2731 | 1 | ("c1", &vec![4, 5, 6]), |
2732 | 1 | ); |
2733 | 1 | let left_batch_2 = build_table_i32( |
2734 | 1 | ("a1", &vec![3, 4, 5, 6]), |
2735 | 1 | ("b1", &vec![6, 6, 7, 9]), |
2736 | 1 | ("c1", &vec![7, 8, 9, 9]), |
2737 | 1 | ); |
2738 | 1 | let right_batch_1 = build_table_i32( |
2739 | 1 | ("a2", &vec![0, 10, 20]), |
2740 | 1 | ("b2", &vec![2, 4, 6]), |
2741 | 1 | ("c2", &vec![50, 60, 70]), |
2742 | 1 | ); |
2743 | 1 | let right_batch_2 = build_table_i32( |
2744 | 1 | ("a2", &vec![30, 40]), |
2745 | 1 | ("b2", &vec![6, 8]), |
2746 | 1 | ("c2", &vec![80, 90]), |
2747 | 1 | ); |
2748 | 1 | let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); |
2749 | 1 | let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); |
2750 | 1 | let on = vec![( |
2751 | 1 | Arc::new(Column::new_with_schema("b1", &left.schema())?0 ) as _, |
2752 | 1 | Arc::new(Column::new_with_schema("b2", &right.schema())?0 ) as _, |
2753 | 1 | )]; |
2754 | 1 | |
2755 | 1 | let (_, batches) = join_collect(left, right, on, JoinType::Left).await0 ?0 ; |
2756 | 1 | let expected = vec![ |
2757 | 1 | "+----+----+----+----+----+----+", |
2758 | 1 | "| a1 | b1 | c1 | a2 | b2 | c2 |", |
2759 | 1 | "+----+----+----+----+----+----+", |
2760 | 1 | "| 0 | 3 | 4 | | | |", |
2761 | 1 | "| 1 | 4 | 5 | 10 | 4 | 60 |", |
2762 | 1 | "| 2 | 5 | 6 | | | |", |
2763 | 1 | "| 3 | 6 | 7 | 20 | 6 | 70 |", |
2764 | 1 | "| 3 | 6 | 7 | 30 | 6 | 80 |", |
2765 | 1 | "| 4 | 6 | 8 | 20 | 6 | 70 |", |
2766 | 1 | "| 4 | 6 | 8 | 30 | 6 | 80 |", |
2767 | 1 | "| 5 | 7 | 9 | | | |", |
2768 | 1 | "| 6 | 9 | 9 | | | |", |
2769 | 1 | "+----+----+----+----+----+----+", |
2770 | 1 | ]; |
2771 | 1 | assert_batches_eq!(expected, &batches); |
2772 | 1 | Ok(()) |
2773 | 1 | } |
2774 | | |
2775 | | #[tokio::test] |
2776 | 1 | async fn join_right_multiple_batches() -> Result<()> { |
2777 | 1 | let right_batch_1 = build_table_i32( |
2778 | 1 | ("a2", &vec![0, 1, 2]), |
2779 | 1 | ("b2", &vec![3, 4, 5]), |
2780 | 1 | ("c2", &vec![4, 5, 6]), |
2781 | 1 | ); |
2782 | 1 | let right_batch_2 = build_table_i32( |
2783 | 1 | ("a2", &vec![3, 4, 5, 6]), |
2784 | 1 | ("b2", &vec![6, 6, 7, 9]), |
2785 | 1 | ("c2", &vec![7, 8, 9, 9]), |
2786 | 1 | ); |
2787 | 1 | let left_batch_1 = build_table_i32( |
2788 | 1 | ("a1", &vec![0, 10, 20]), |
2789 | 1 | ("b1", &vec![2, 4, 6]), |
2790 | 1 | ("c1", &vec![50, 60, 70]), |
2791 | 1 | ); |
2792 | 1 | let left_batch_2 = build_table_i32( |
2793 | 1 | ("a1", &vec![30, 40]), |
2794 | 1 | ("b1", &vec![6, 8]), |
2795 | 1 | ("c1", &vec![80, 90]), |
2796 | 1 | ); |
2797 | 1 | let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); |
2798 | 1 | let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); |
2799 | 1 | let on = vec![( |
2800 | 1 | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2801 | 1 | Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2802 | 1 | )]; |
2803 | 1 | |
2804 | 1 | let (_, batches) = join_collect(left, right, on, JoinType::Right).await?;
2805 | 1 | let expected = vec![ |
2806 | 1 | "+----+----+----+----+----+----+", |
2807 | 1 | "| a1 | b1 | c1 | a2 | b2 | c2 |", |
2808 | 1 | "+----+----+----+----+----+----+", |
2809 | 1 | "| | | | 0 | 3 | 4 |", |
2810 | 1 | "| 10 | 4 | 60 | 1 | 4 | 5 |", |
2811 | 1 | "| | | | 2 | 5 | 6 |", |
2812 | 1 | "| 20 | 6 | 70 | 3 | 6 | 7 |", |
2813 | 1 | "| 30 | 6 | 80 | 3 | 6 | 7 |", |
2814 | 1 | "| 20 | 6 | 70 | 4 | 6 | 8 |", |
2815 | 1 | "| 30 | 6 | 80 | 4 | 6 | 8 |", |
2816 | 1 | "| | | | 5 | 7 | 9 |", |
2817 | 1 | "| | | | 6 | 9 | 9 |", |
2818 | 1 | "+----+----+----+----+----+----+", |
2819 | 1 | ]; |
2820 | 1 | assert_batches_eq!(expected, &batches); |
2821 | 1 | Ok(()) |
2822 | 1 | } |
2823 | | |
2824 | | #[tokio::test] |
2825 | 1 | async fn join_full_multiple_batches() -> Result<()> { |
2826 | 1 | let left_batch_1 = build_table_i32( |
2827 | 1 | ("a1", &vec![0, 1, 2]), |
2828 | 1 | ("b1", &vec![3, 4, 5]), |
2829 | 1 | ("c1", &vec![4, 5, 6]), |
2830 | 1 | ); |
2831 | 1 | let left_batch_2 = build_table_i32( |
2832 | 1 | ("a1", &vec![3, 4, 5, 6]), |
2833 | 1 | ("b1", &vec![6, 6, 7, 9]), |
2834 | 1 | ("c1", &vec![7, 8, 9, 9]), |
2835 | 1 | ); |
2836 | 1 | let right_batch_1 = build_table_i32( |
2837 | 1 | ("a2", &vec![0, 10, 20]), |
2838 | 1 | ("b2", &vec![2, 4, 6]), |
2839 | 1 | ("c2", &vec![50, 60, 70]), |
2840 | 1 | ); |
2841 | 1 | let right_batch_2 = build_table_i32( |
2842 | 1 | ("a2", &vec![30, 40]), |
2843 | 1 | ("b2", &vec![6, 8]), |
2844 | 1 | ("c2", &vec![80, 90]), |
2845 | 1 | ); |
2846 | 1 | let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); |
2847 | 1 | let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); |
2848 | 1 | let on = vec![( |
2849 | 1 | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2850 | 1 | Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2851 | 1 | )]; |
2852 | 1 | |
2853 | 1 | let (_, batches) = join_collect(left, right, on, JoinType::Full).await?;
2854 | 1 | let expected = vec![ |
2855 | 1 | "+----+----+----+----+----+----+", |
2856 | 1 | "| a1 | b1 | c1 | a2 | b2 | c2 |", |
2857 | 1 | "+----+----+----+----+----+----+", |
2858 | 1 | "| | | | 0 | 2 | 50 |", |
2859 | 1 | "| | | | 40 | 8 | 90 |", |
2860 | 1 | "| 0 | 3 | 4 | | | |", |
2861 | 1 | "| 1 | 4 | 5 | 10 | 4 | 60 |", |
2862 | 1 | "| 2 | 5 | 6 | | | |", |
2863 | 1 | "| 3 | 6 | 7 | 20 | 6 | 70 |", |
2864 | 1 | "| 3 | 6 | 7 | 30 | 6 | 80 |", |
2865 | 1 | "| 4 | 6 | 8 | 20 | 6 | 70 |", |
2866 | 1 | "| 4 | 6 | 8 | 30 | 6 | 80 |", |
2867 | 1 | "| 5 | 7 | 9 | | | |", |
2868 | 1 | "| 6 | 9 | 9 | | | |", |
2869 | 1 | "+----+----+----+----+----+----+", |
2870 | 1 | ]; |
2871 | 1 | assert_batches_sorted_eq!(expected, &batches); |
2872 | 1 | Ok(()) |
2873 | 1 | } |
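// Note (an editorial inference, not stated in the source): unlike the
// Left/Right variants above, which use assert_batches_eq!, the Full join test
// compares with assert_batches_sorted_eq!. A plausible reason is that the
// position at which unmatched buffered-side rows are emitted depends on batch
// boundaries, so a Full join's row order is not stable enough for an exact,
// order-sensitive comparison.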
2874 | | |
2875 | | #[tokio::test] |
2876 | 1 | async fn overallocation_single_batch_no_spill() -> Result<()> { |
2877 | 1 | let left = build_table( |
2878 | 1 | ("a1", &vec![0, 1, 2, 3, 4, 5]), |
2879 | 1 | ("b1", &vec![1, 2, 3, 4, 5, 6]), |
2880 | 1 | ("c1", &vec![4, 5, 6, 7, 8, 9]), |
2881 | 1 | ); |
2882 | 1 | let right = build_table( |
2883 | 1 | ("a2", &vec![0, 10, 20, 30, 40]), |
2884 | 1 | ("b2", &vec![1, 3, 4, 6, 8]), |
2885 | 1 | ("c2", &vec![50, 60, 70, 80, 90]), |
2886 | 1 | ); |
2887 | 1 | let on = vec![( |
2888 | 1 | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2889 | 1 | Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2890 | 1 | )]; |
2891 | 1 | let sort_options = vec![SortOptions::default(); on.len()]; |
2892 | 1 | |
2893 | 1 | let join_types = vec![ |
2894 | 1 | JoinType::Inner, |
2895 | 1 | JoinType::Left, |
2896 | 1 | JoinType::Right, |
2897 | 1 | JoinType::Full, |
2898 | 1 | JoinType::LeftSemi, |
2899 | 1 | JoinType::LeftAnti, |
2900 | 1 | ]; |
2901 | 1 | |
2902 | 1 | // Disable DiskManager to prevent spilling |
2903 | 1 | let runtime = RuntimeEnvBuilder::new() |
2904 | 1 | .with_memory_limit(100, 1.0) |
2905 | 1 | .with_disk_manager(DiskManagerConfig::Disabled) |
2906 | 1 | .build_arc()?;
2907 | 1 | let session_config = SessionConfig::default().with_batch_size(50); |
2908 | 1 | |
2909 | 7 | for join_type in join_types {
2910 | 6 | let task_ctx = TaskContext::default() |
2911 | 6 | .with_session_config(session_config.clone()) |
2912 | 6 | .with_runtime(Arc::clone(&runtime)); |
2913 | 6 | let task_ctx = Arc::new(task_ctx); |
2914 | 1 | |
2915 | 6 | let join = join_with_options( |
2916 | 6 | Arc::clone(&left), |
2917 | 6 | Arc::clone(&right), |
2918 | 6 | on.clone(), |
2919 | 6 | join_type, |
2920 | 6 | sort_options.clone(), |
2921 | 6 | false, |
2922 | 6 | )?;
2923 | 1 | |
2924 | 6 | let stream = join.execute(0, task_ctx)?;
2925 | 6 | let err = common::collect(stream).await.unwrap_err();
2926 | 6 | |
2927 | 6 | assert_contains!(err.to_string(), "Failed to allocate additional"); |
2928 | 6 | assert_contains!(err.to_string(), "SMJStream[0]"); |
2929 | 6 | assert_contains!(err.to_string(), "Disk spilling disabled"); |
2930 | 6 | assert!(join.metrics().is_some()); |
2931 | 6 | assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); |
2932 | 6 | assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); |
2933 | 6 | assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); |
2934 | 1 | } |
2935 | 1 | |
2936 | 1 | Ok(()) |
2937 | 1 | } |
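// Note (an editorial sketch, not part of the original test file): the
// overallocation tests inject failure by capping the memory pool far below
// what the join needs while disabling the disk manager, so the operator can
// neither allocate nor spill. Distilled to its core, using only calls that
// appear above:
//
//     let runtime = RuntimeEnvBuilder::new()
//         .with_memory_limit(100, 1.0) // tiny 100-byte pool
//         .with_disk_manager(DiskManagerConfig::Disabled)
//         .build_arc()?;
//     // Any SortMergeJoin executed under this runtime is then expected to
//     // fail with a "Failed to allocate additional" resource error while
//     // all spill metrics stay at zero.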
2938 | | |
2939 | | #[tokio::test] |
2940 | 1 | async fn overallocation_multi_batch_no_spill() -> Result<()> { |
2941 | 1 | let left_batch_1 = build_table_i32( |
2942 | 1 | ("a1", &vec![0, 1]), |
2943 | 1 | ("b1", &vec![1, 1]), |
2944 | 1 | ("c1", &vec![4, 5]), |
2945 | 1 | ); |
2946 | 1 | let left_batch_2 = build_table_i32( |
2947 | 1 | ("a1", &vec![2, 3]), |
2948 | 1 | ("b1", &vec![1, 1]), |
2949 | 1 | ("c1", &vec![6, 7]), |
2950 | 1 | ); |
2951 | 1 | let left_batch_3 = build_table_i32( |
2952 | 1 | ("a1", &vec![4, 5]), |
2953 | 1 | ("b1", &vec![1, 1]), |
2954 | 1 | ("c1", &vec![8, 9]), |
2955 | 1 | ); |
2956 | 1 | let right_batch_1 = build_table_i32( |
2957 | 1 | ("a2", &vec![0, 10]), |
2958 | 1 | ("b2", &vec![1, 1]), |
2959 | 1 | ("c2", &vec![50, 60]), |
2960 | 1 | ); |
2961 | 1 | let right_batch_2 = build_table_i32( |
2962 | 1 | ("a2", &vec![20, 30]), |
2963 | 1 | ("b2", &vec![1, 1]), |
2964 | 1 | ("c2", &vec![70, 80]), |
2965 | 1 | ); |
2966 | 1 | let right_batch_3 = |
2967 | 1 | build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", &vec![90])); |
2968 | 1 | let left = |
2969 | 1 | build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]); |
2970 | 1 | let right = |
2971 | 1 | build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]); |
2972 | 1 | let on = vec![( |
2973 | 1 | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
2974 | 1 | Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
2975 | 1 | )]; |
2976 | 1 | let sort_options = vec![SortOptions::default(); on.len()]; |
2977 | 1 | |
2978 | 1 | let join_types = vec![ |
2979 | 1 | JoinType::Inner, |
2980 | 1 | JoinType::Left, |
2981 | 1 | JoinType::Right, |
2982 | 1 | JoinType::Full, |
2983 | 1 | JoinType::LeftSemi, |
2984 | 1 | JoinType::LeftAnti, |
2985 | 1 | ]; |
2986 | 1 | |
2987 | 1 | // Disable DiskManager to prevent spilling |
2988 | 1 | let runtime = RuntimeEnvBuilder::new() |
2989 | 1 | .with_memory_limit(100, 1.0) |
2990 | 1 | .with_disk_manager(DiskManagerConfig::Disabled) |
2991 | 1 | .build_arc()?;
2992 | 1 | let session_config = SessionConfig::default().with_batch_size(50); |
2993 | 1 | |
2994 | 7 | for join_type in join_types {
2995 | 6 | let task_ctx = TaskContext::default() |
2996 | 6 | .with_session_config(session_config.clone()) |
2997 | 6 | .with_runtime(Arc::clone(&runtime)); |
2998 | 6 | let task_ctx = Arc::new(task_ctx); |
2999 | 6 | let join = join_with_options( |
3000 | 6 | Arc::clone(&left), |
3001 | 6 | Arc::clone(&right), |
3002 | 6 | on.clone(), |
3003 | 6 | join_type, |
3004 | 6 | sort_options.clone(), |
3005 | 6 | false, |
3006 | 6 | )?;
3007 | 1 | |
3008 | 6 | let stream = join.execute(0, task_ctx)?;
3009 | 6 | let err = common::collect(stream).await.unwrap_err();
3010 | 6 | |
3011 | 6 | assert_contains!(err.to_string(), "Failed to allocate additional"); |
3012 | 6 | assert_contains!(err.to_string(), "SMJStream[0]"); |
3013 | 6 | assert_contains!(err.to_string(), "Disk spilling disabled"); |
3014 | 6 | assert!(join.metrics().is_some()); |
3015 | 6 | assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); |
3016 | 6 | assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); |
3017 | 6 | assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); |
3018 | 1 | } |
3019 | 1 | |
3020 | 1 | Ok(()) |
3021 | 1 | } |
3022 | | |
3023 | | #[tokio::test] |
3024 | 1 | async fn overallocation_single_batch_spill() -> Result<()> { |
3025 | 1 | let left = build_table( |
3026 | 1 | ("a1", &vec![0, 1, 2, 3, 4, 5]), |
3027 | 1 | ("b1", &vec![1, 2, 3, 4, 5, 6]), |
3028 | 1 | ("c1", &vec![4, 5, 6, 7, 8, 9]), |
3029 | 1 | ); |
3030 | 1 | let right = build_table( |
3031 | 1 | ("a2", &vec![0, 10, 20, 30, 40]), |
3032 | 1 | ("b2", &vec![1, 3, 4, 6, 8]), |
3033 | 1 | ("c2", &vec![50, 60, 70, 80, 90]), |
3034 | 1 | ); |
3035 | 1 | let on = vec![( |
3036 | 1 | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3037 | 1 | Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
3038 | 1 | )]; |
3039 | 1 | let sort_options = vec![SortOptions::default(); on.len()]; |
3040 | 1 | |
3041 | 1 | let join_types = [ |
3042 | 1 | JoinType::Inner, |
3043 | 1 | JoinType::Left, |
3044 | 1 | JoinType::Right, |
3045 | 1 | JoinType::Full, |
3046 | 1 | JoinType::LeftSemi, |
3047 | 1 | JoinType::LeftAnti, |
3048 | 1 | ]; |
3049 | 1 | |
3050 | 1 | // Enable DiskManager to allow spilling |
3051 | 1 | let runtime = RuntimeEnvBuilder::new() |
3052 | 1 | .with_memory_limit(100, 1.0) |
3053 | 1 | .with_disk_manager(DiskManagerConfig::NewOs) |
3054 | 1 | .build_arc()?;
3055 | 1 | |
3056 | 3 | for batch_size in [1, 50] {
3057 | 2 | let session_config = SessionConfig::default().with_batch_size(batch_size); |
3058 | 1 | |
3059 | 14 | for join_type in &join_types {
3060 | 12 | let task_ctx = TaskContext::default() |
3061 | 12 | .with_session_config(session_config.clone()) |
3062 | 12 | .with_runtime(Arc::clone(&runtime)); |
3063 | 12 | let task_ctx = Arc::new(task_ctx); |
3064 | 1 | |
3065 | 12 | let join = join_with_options( |
3066 | 12 | Arc::clone(&left), |
3067 | 12 | Arc::clone(&right), |
3068 | 12 | on.clone(), |
3069 | 12 | *join_type, |
3070 | 12 | sort_options.clone(), |
3071 | 12 | false, |
3072 | 12 | )?;
3073 | 1 | |
3074 | 12 | let stream = join.execute(0, task_ctx)?;
3075 | 12 | let spilled_join_result = common::collect(stream).await.unwrap();
3076 | 12 | |
3077 | 12 | assert!(join.metrics().is_some()); |
3078 | 12 | assert!(join.metrics().unwrap().spill_count().unwrap() > 0); |
3079 | 12 | assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); |
3080 | 12 | assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); |
3081 | 1 | |
3082 | 1 | // Run the test with no spill configuration as a baseline
3083 | 12 | let task_ctx_no_spill = |
3084 | 12 | TaskContext::default().with_session_config(session_config.clone()); |
3085 | 12 | let task_ctx_no_spill = Arc::new(task_ctx_no_spill); |
3086 | 1 | |
3087 | 12 | let join = join_with_options( |
3088 | 12 | Arc::clone(&left), |
3089 | 12 | Arc::clone(&right), |
3090 | 12 | on.clone(), |
3091 | 12 | *join_type, |
3092 | 12 | sort_options.clone(), |
3093 | 12 | false, |
3094 | 12 | )?0 ; |
3095 | 12 | let stream = join.execute(0, task_ctx_no_spill)?;
3096 | 12 | let no_spilled_join_result = common::collect(stream).await.unwrap();
3097 | 12 | |
3098 | 12 | assert!(join.metrics().is_some()); |
3099 | 12 | assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); |
3100 | 12 | assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); |
3101 | 12 | assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); |
3102 | 1 | // Compare spilled and non-spilled data to check that spill logic doesn't corrupt the data
3103 | 12 | assert_eq!(spilled_join_result, no_spilled_join_result); |
3104 | 1 | } |
3105 | 1 | } |
3106 | 1 | |
3107 | 1 | Ok(()) |
3108 | 1 | } |
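// Note (an editorial summary of the test above, not from the source): the
// spill path is validated in two steps. First the join runs under the
// memory-limited runtime and the test asserts spill_count, spilled_bytes, and
// spilled_rows are all non-zero; then the identical join reruns on a default
// TaskContext (no memory limit, so no spilling) and the two result sets are
// asserted equal. Equality of the spilled and non-spilled outputs is what
// shows the spill/read-back round trip does not corrupt data.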
3109 | | |
3110 | | #[tokio::test] |
3111 | 1 | async fn overallocation_multi_batch_spill() -> Result<()> { |
3112 | 1 | let left_batch_1 = build_table_i32( |
3113 | 1 | ("a1", &vec![0, 1]), |
3114 | 1 | ("b1", &vec![1, 1]), |
3115 | 1 | ("c1", &vec![4, 5]), |
3116 | 1 | ); |
3117 | 1 | let left_batch_2 = build_table_i32( |
3118 | 1 | ("a1", &vec![2, 3]), |
3119 | 1 | ("b1", &vec![1, 1]), |
3120 | 1 | ("c1", &vec![6, 7]), |
3121 | 1 | ); |
3122 | 1 | let left_batch_3 = build_table_i32( |
3123 | 1 | ("a1", &vec![4, 5]), |
3124 | 1 | ("b1", &vec![1, 1]), |
3125 | 1 | ("c1", &vec![8, 9]), |
3126 | 1 | ); |
3127 | 1 | let right_batch_1 = build_table_i32( |
3128 | 1 | ("a2", &vec![0, 10]), |
3129 | 1 | ("b2", &vec![1, 1]), |
3130 | 1 | ("c2", &vec![50, 60]), |
3131 | 1 | ); |
3132 | 1 | let right_batch_2 = build_table_i32( |
3133 | 1 | ("a2", &vec![20, 30]), |
3134 | 1 | ("b2", &vec![1, 1]), |
3135 | 1 | ("c2", &vec![70, 80]), |
3136 | 1 | ); |
3137 | 1 | let right_batch_3 = |
3138 | 1 | build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", &vec![90])); |
3139 | 1 | let left = |
3140 | 1 | build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]); |
3141 | 1 | let right = |
3142 | 1 | build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]); |
3143 | 1 | let on = vec![( |
3144 | 1 | Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
3145 | 1 | Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
3146 | 1 | )]; |
3147 | 1 | let sort_options = vec![SortOptions::default(); on.len()]; |
3148 | 1 | |
3149 | 1 | let join_types = [ |
3150 | 1 | JoinType::Inner, |
3151 | 1 | JoinType::Left, |
3152 | 1 | JoinType::Right, |
3153 | 1 | JoinType::Full, |
3154 | 1 | JoinType::LeftSemi, |
3155 | 1 | JoinType::LeftAnti, |
3156 | 1 | ]; |
3157 | 1 | |
3158 | 1 | // Enable DiskManager to allow spilling |
3159 | 1 | let runtime = RuntimeEnvBuilder::new() |
3160 | 1 | .with_memory_limit(500, 1.0) |
3161 | 1 | .with_disk_manager(DiskManagerConfig::NewOs) |
3162 | 1 | .build_arc()?;
3163 | 1 | |
3164 | 3 | for batch_size in [1, 50] {
3165 | 2 | let session_config = SessionConfig::default().with_batch_size(batch_size); |
3166 | 1 | |
3167 | 14 | for join_type in &join_types {
3168 | 12 | let task_ctx = TaskContext::default() |
3169 | 12 | .with_session_config(session_config.clone()) |
3170 | 12 | .with_runtime(Arc::clone(&runtime)); |
3171 | 12 | let task_ctx = Arc::new(task_ctx); |
3172 | 12 | let join = join_with_options( |
3173 | 12 | Arc::clone(&left), |
3174 | 12 | Arc::clone(&right), |
3175 | 12 | on.clone(), |
3176 | 12 | *join_type, |
3177 | 12 | sort_options.clone(), |
3178 | 12 | false, |
3179 | 12 | )?;
3180 | 1 | |
3181 | 12 | let stream = join.execute(0, task_ctx)?;
3182 | 12 | let spilled_join_result = common::collect(stream).await.unwrap();
3183 | 12 | assert!(join.metrics().is_some()); |
3184 | 12 | assert!(join.metrics().unwrap().spill_count().unwrap() > 0); |
3185 | 12 | assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); |
3186 | 12 | assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); |
3187 | 1 | |
3188 | 1 | // Run the test with no spill configuration as a baseline
3189 | 12 | let task_ctx_no_spill = |
3190 | 12 | TaskContext::default().with_session_config(session_config.clone()); |
3191 | 12 | let task_ctx_no_spill = Arc::new(task_ctx_no_spill); |
3192 | 1 | |
3193 | 12 | let join = join_with_options( |
3194 | 12 | Arc::clone(&left), |
3195 | 12 | Arc::clone(&right), |
3196 | 12 | on.clone(), |
3197 | 12 | *join_type, |
3198 | 12 | sort_options.clone(), |
3199 | 12 | false, |
3200 | 12 | )?;
3201 | 12 | let stream = join.execute(0, task_ctx_no_spill)?;
3202 | 12 | let no_spilled_join_result = common::collect(stream).await.unwrap();
3203 | 12 | |
3204 | 12 | assert!(join.metrics().is_some()); |
3205 | 12 | assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); |
3206 | 12 | assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); |
3207 | 12 | assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); |
3208 | 1 | // Compare spilled and non-spilled data to check that spill logic doesn't corrupt the data
3209 | 12 | assert_eq!(spilled_join_result, no_spilled_join_result); |
3210 | 1 | } |
3211 | 1 | } |
3212 | 1 | |
3213 | 1 | Ok(()) |
3214 | 1 | } |
3215 | | |
3216 | | #[tokio::test] |
3217 | 1 | async fn left_semi_join_filtered_mask() -> Result<()> { |
3218 | 1 | assert_eq!( |
3219 | 1 | get_filtered_join_mask( |
3220 | 1 | LeftSemi, |
3221 | 1 | &UInt64Array::from(vec![0, 0, 1, 1]), |
3222 | 1 | &BooleanArray::from(vec![true, true, false, false]), |
3223 | 1 | &HashSet::new(), |
3224 | 1 | &0, |
3225 | 1 | ), |
3226 | 1 | Some((BooleanArray::from(vec![true, false, false, false]), vec![0])) |
3227 | 1 | ); |
3228 | 1 | |
3229 | 1 | assert_eq!( |
3230 | 1 | get_filtered_join_mask( |
3231 | 1 | LeftSemi, |
3232 | 1 | &UInt64Array::from(vec![0, 1]), |
3233 | 1 | &BooleanArray::from(vec![true, true]), |
3234 | 1 | &HashSet::new(), |
3235 | 1 | &0, |
3236 | 1 | ), |
3237 | 1 | Some((BooleanArray::from(vec![true, true]), vec![0, 1])) |
3238 | 1 | ); |
3239 | 1 | |
3240 | 1 | assert_eq!( |
3241 | 1 | get_filtered_join_mask( |
3242 | 1 | LeftSemi, |
3243 | 1 | &UInt64Array::from(vec![0, 1]), |
3244 | 1 | &BooleanArray::from(vec![false, true]), |
3245 | 1 | &HashSet::new(), |
3246 | 1 | &0, |
3247 | 1 | ), |
3248 | 1 | Some((BooleanArray::from(vec![false, true]), vec![1])) |
3249 | 1 | ); |
3250 | 1 | |
3251 | 1 | assert_eq!( |
3252 | 1 | get_filtered_join_mask( |
3253 | 1 | LeftSemi, |
3254 | 1 | &UInt64Array::from(vec![0, 1]), |
3255 | 1 | &BooleanArray::from(vec![true, false]), |
3256 | 1 | &HashSet::new(), |
3257 | 1 | &0, |
3258 | 1 | ), |
3259 | 1 | Some((BooleanArray::from(vec![true, false]), vec![0])) |
3260 | 1 | ); |
3261 | 1 | |
3262 | 1 | assert_eq!( |
3263 | 1 | get_filtered_join_mask( |
3264 | 1 | LeftSemi, |
3265 | 1 | &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), |
3266 | 1 | &BooleanArray::from(vec![false, true, true, true, true, true]), |
3267 | 1 | &HashSet::new(), |
3268 | 1 | &0, |
3269 | 1 | ), |
3270 | 1 | Some(( |
3271 | 1 | BooleanArray::from(vec![false, true, false, true, false, false]), |
3272 | 1 | vec![0, 1] |
3273 | 1 | )) |
3274 | 1 | ); |
3275 | 1 | |
3276 | 1 | assert_eq!( |
3277 | 1 | get_filtered_join_mask( |
3278 | 1 | LeftSemi, |
3279 | 1 | &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), |
3280 | 1 | &BooleanArray::from(vec![false, false, false, false, false, true]), |
3281 | 1 | &HashSet::new(), |
3282 | 1 | &0, |
3283 | 1 | ), |
3284 | 1 | Some(( |
3285 | 1 | BooleanArray::from(vec![false, false, false, false, false, true]), |
3286 | 1 | vec![1] |
3287 | 1 | )) |
3288 | 1 | ); |
3289 | 1 | |
3290 | 1 | assert_eq!( |
3291 | 1 | get_filtered_join_mask( |
3292 | 1 | LeftSemi, |
3293 | 1 | &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), |
3294 | 1 | &BooleanArray::from(vec![true, false, false, false, false, true]), |
3295 | 1 | &HashSet::from_iter(vec![1]), |
3296 | 1 | &0, |
3297 | 1 | ), |
3298 | 1 | Some(( |
3299 | 1 | BooleanArray::from(vec![true, false, false, false, false, false]), |
3300 | 1 | vec![0] |
3301 | 1 | )) |
3302 | 1 | ); |
3303 | 1 | |
3304 | 1 | Ok(()) |
3305 | 1 | } |
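// Note (an editorial worked reading of the first LeftSemi assertion above,
// inferred from the assertions rather than stated in the source). The
// UInt64Array holds the streamed-side row index that produced each filtered
// row, and the BooleanArray is the join filter result per row:
//
//     row indices: [0, 0, 1, 1]
//     filter mask: [true, true, false, false]
//
// LeftSemi emits each streamed row at most once, so only the first
// filter-passing row for index 0 stays true, giving the output mask
// [true, false, false, false]; the accompanying vec![0] records which indices
// matched. Per the final assertion, an index already present in the HashSet
// (there, index 1) has its later filter-passing rows suppressed, which is how
// matches are deduplicated across batches.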
3306 | | |
3307 | | #[tokio::test] |
3308 | 1 | async fn left_anti_join_filtered_mask() -> Result<()> { |
3309 | 1 | assert_eq!( |
3310 | 1 | get_filtered_join_mask( |
3311 | 1 | LeftAnti, |
3312 | 1 | &UInt64Array::from(vec![0, 0, 1, 1]), |
3313 | 1 | &BooleanArray::from(vec![true, true, false, false]), |
3314 | 1 | &HashSet::new(), |
3315 | 1 | &0, |
3316 | 1 | ), |
3317 | 1 | Some((BooleanArray::from(vec![false, false, false, true]), vec![0])) |
3318 | 1 | ); |
3319 | 1 | |
3320 | 1 | assert_eq!( |
3321 | 1 | get_filtered_join_mask( |
3322 | 1 | LeftAnti, |
3323 | 1 | &UInt64Array::from(vec![0, 1]), |
3324 | 1 | &BooleanArray::from(vec![true, true]), |
3325 | 1 | &HashSet::new(), |
3326 | 1 | &0, |
3327 | 1 | ), |
3328 | 1 | Some((BooleanArray::from(vec![false, false]), vec![0, 1])) |
3329 | 1 | ); |
3330 | 1 | |
3331 | 1 | assert_eq!( |
3332 | 1 | get_filtered_join_mask( |
3333 | 1 | LeftAnti, |
3334 | 1 | &UInt64Array::from(vec![0, 1]), |
3335 | 1 | &BooleanArray::from(vec![false, true]), |
3336 | 1 | &HashSet::new(), |
3337 | 1 | &0, |
3338 | 1 | ), |
3339 | 1 | Some((BooleanArray::from(vec![true, false]), vec![1])) |
3340 | 1 | ); |
3341 | 1 | |
3342 | 1 | assert_eq!( |
3343 | 1 | get_filtered_join_mask( |
3344 | 1 | LeftAnti, |
3345 | 1 | &UInt64Array::from(vec![0, 1]), |
3346 | 1 | &BooleanArray::from(vec![true, false]), |
3347 | 1 | &HashSet::new(), |
3348 | 1 | &0, |
3349 | 1 | ), |
3350 | 1 | Some((BooleanArray::from(vec![false, true]), vec![0])) |
3351 | 1 | ); |
3352 | 1 | |
3353 | 1 | assert_eq!( |
3354 | 1 | get_filtered_join_mask( |
3355 | 1 | LeftAnti, |
3356 | 1 | &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), |
3357 | 1 | &BooleanArray::from(vec![false, true, true, true, true, true]), |
3358 | 1 | &HashSet::new(), |
3359 | 1 | &0, |
3360 | 1 | ), |
3361 | 1 | Some(( |
3362 | 1 | BooleanArray::from(vec![false, false, false, false, false, false]), |
3363 | 1 | vec![0, 1] |
3364 | 1 | )) |
3365 | 1 | ); |
3366 | 1 | |
3367 | 1 | assert_eq!( |
3368 | 1 | get_filtered_join_mask( |
3369 | 1 | LeftAnti, |
3370 | 1 | &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), |
3371 | 1 | &BooleanArray::from(vec![false, false, false, false, false, true]), |
3372 | 1 | &HashSet::new(), |
3373 | 1 | &0, |
3374 | 1 | ), |
3375 | 1 | Some(( |
3376 | 1 | BooleanArray::from(vec![false, false, true, false, false, false]), |
3377 | 1 | vec![1] |
3378 | 1 | )) |
3379 | 1 | ); |
3380 | 1 | |
3381 | 1 | Ok(()) |
3382 | 1 | } |
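// Note (an editorial worked reading, inferred from the assertions): the
// LeftAnti cases invert the LeftSemi logic. Tracing the first assertion,
// indices [0, 0, 1, 1] with filter mask [true, true, false, false]: index 0
// passed the filter somewhere, so all of its rows stay false; index 1 never
// passed, so it turns true at its last occurrence, yielding
// [false, false, false, true]. The returned vec![0] again lists the indices
// that did match, so they can be excluded if they reappear in a later batch.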
3383 | | |
3384 | | /// Returns the column names on the schema |
3385 | 19 | fn columns(schema: &Schema) -> Vec<String> { |
3386 | 108 | schema.fields().iter().map(|f| f.name().clone()).collect() |
3387 | 19 | } |
3388 | | } |