/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/joins/cross_join.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Defines the cross join plan for loading the left side of the cross join |
19 | | //! and producing batches in parallel for the right partitions |
20 | | |
21 | | use super::utils::{ |
22 | | adjust_right_output_partitioning, BuildProbeJoinMetrics, OnceAsync, OnceFut, |
23 | | StatefulStreamResult, |
24 | | }; |
25 | | use crate::coalesce_partitions::CoalescePartitionsExec; |
26 | | use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; |
27 | | use crate::{ |
28 | | execution_mode_from_children, handle_state, ColumnStatistics, DisplayAs, |
29 | | DisplayFormatType, Distribution, ExecutionMode, ExecutionPlan, |
30 | | ExecutionPlanProperties, PlanProperties, RecordBatchStream, |
31 | | SendableRecordBatchStream, Statistics, |
32 | | }; |
33 | | use arrow::compute::concat_batches; |
34 | | use std::{any::Any, sync::Arc, task::Poll}; |
35 | | |
36 | | use arrow::datatypes::{Fields, Schema, SchemaRef}; |
37 | | use arrow::record_batch::RecordBatch; |
38 | | use arrow_array::RecordBatchOptions; |
39 | | use datafusion_common::stats::Precision; |
40 | | use datafusion_common::{internal_err, JoinType, Result, ScalarValue}; |
41 | | use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; |
42 | | use datafusion_execution::TaskContext; |
43 | | use datafusion_physical_expr::equivalence::join_equivalence_properties; |
44 | | |
45 | | use async_trait::async_trait; |
46 | | use futures::{ready, Stream, StreamExt, TryStreamExt}; |
47 | | |
48 | | /// Data of the left side |
49 | | type JoinLeftData = (RecordBatch, MemoryReservation); |
50 | | |
51 | | /// executes partitions in parallel and combines them into a set of |
52 | | /// partitions by combining all values from the left with all values on the right |
53 | | #[derive(Debug)] |
54 | | pub struct CrossJoinExec { |
55 | | /// left (build) side which gets loaded in memory |
56 | | pub left: Arc<dyn ExecutionPlan>, |
57 | | /// right (probe) side which are combined with left side |
58 | | pub right: Arc<dyn ExecutionPlan>, |
59 | | /// The schema once the join is applied |
60 | | schema: SchemaRef, |
61 | | /// Build-side data |
62 | | left_fut: OnceAsync<JoinLeftData>, |
63 | | /// Execution plan metrics |
64 | | metrics: ExecutionPlanMetricsSet, |
65 | | cache: PlanProperties, |
66 | | } |
67 | | |
68 | | impl CrossJoinExec { |
69 | | /// Create a new [CrossJoinExec]. |
70 | 2 | pub fn new(left: Arc<dyn ExecutionPlan>, right: Arc<dyn ExecutionPlan>) -> Self { |
71 | 2 | // left then right |
72 | 2 | let (all_columns, metadata) = { |
73 | 2 | let left_schema = left.schema(); |
74 | 2 | let right_schema = right.schema(); |
75 | 2 | let left_fields = left_schema.fields().iter(); |
76 | 2 | let right_fields = right_schema.fields().iter(); |
77 | 2 | |
78 | 2 | let mut metadata = left_schema.metadata().clone(); |
79 | 2 | metadata.extend(right_schema.metadata().clone()); |
80 | 2 | |
81 | 2 | ( |
82 | 2 | left_fields.chain(right_fields).cloned().collect::<Fields>(), |
83 | 2 | metadata, |
84 | 2 | ) |
85 | 2 | }; |
86 | 2 | |
87 | 2 | let schema = Arc::new(Schema::new(all_columns).with_metadata(metadata)); |
88 | 2 | let cache = Self::compute_properties(&left, &right, Arc::clone(&schema)); |
89 | 2 | CrossJoinExec { |
90 | 2 | left, |
91 | 2 | right, |
92 | 2 | schema, |
93 | 2 | left_fut: Default::default(), |
94 | 2 | metrics: ExecutionPlanMetricsSet::default(), |
95 | 2 | cache, |
96 | 2 | } |
97 | 2 | } |
98 | | |
99 | | /// left (build) side which gets loaded in memory |
100 | 2 | pub fn left(&self) -> &Arc<dyn ExecutionPlan> { |
101 | 2 | &self.left |
102 | 2 | } |
103 | | |
104 | | /// right side which gets combined with left side |
105 | 0 | pub fn right(&self) -> &Arc<dyn ExecutionPlan> { |
106 | 0 | &self.right |
107 | 0 | } |
108 | | |
109 | | /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. |
110 | 2 | fn compute_properties( |
111 | 2 | left: &Arc<dyn ExecutionPlan>, |
112 | 2 | right: &Arc<dyn ExecutionPlan>, |
113 | 2 | schema: SchemaRef, |
114 | 2 | ) -> PlanProperties { |
115 | 2 | // Calculate equivalence properties |
116 | 2 | // TODO: Check equivalence properties of cross join, it may preserve |
117 | 2 | // ordering in some cases. |
118 | 2 | let eq_properties = join_equivalence_properties( |
119 | 2 | left.equivalence_properties().clone(), |
120 | 2 | right.equivalence_properties().clone(), |
121 | 2 | &JoinType::Full, |
122 | 2 | schema, |
123 | 2 | &[false, false], |
124 | 2 | None, |
125 | 2 | &[], |
126 | 2 | ); |
127 | 2 | |
128 | 2 | // Get output partitioning: |
129 | 2 | // TODO: Optimize the cross join implementation to generate M * N |
130 | 2 | // partitions. |
131 | 2 | let output_partitioning = adjust_right_output_partitioning( |
132 | 2 | right.output_partitioning(), |
133 | 2 | left.schema().fields.len(), |
134 | 2 | ); |
135 | 2 | |
136 | 2 | // Determine the execution mode: |
137 | 2 | let mut mode = execution_mode_from_children([left, right]); |
138 | 2 | if mode.is_unbounded() { |
139 | 0 | // If any of the inputs is unbounded, cross join breaks the pipeline. |
140 | 0 | mode = ExecutionMode::PipelineBreaking; |
141 | 2 | } |
142 | | |
143 | 2 | PlanProperties::new(eq_properties, output_partitioning, mode) |
144 | 2 | } |
145 | | } |
146 | | |
147 | | /// Asynchronously collect the result of the left child |
148 | 2 | async fn load_left_input( |
149 | 2 | left: Arc<dyn ExecutionPlan>, |
150 | 2 | context: Arc<TaskContext>, |
151 | 2 | metrics: BuildProbeJoinMetrics, |
152 | 2 | reservation: MemoryReservation, |
153 | 2 | ) -> Result<JoinLeftData> { |
154 | 2 | // merge all left parts into a single stream |
155 | 2 | let left_schema = left.schema(); |
156 | 2 | let merge = if left.output_partitioning().partition_count() != 1 { |
157 | 0 | Arc::new(CoalescePartitionsExec::new(left)) |
158 | | } else { |
159 | 2 | left |
160 | | }; |
161 | 2 | let stream = merge.execute(0, context)?0 ; |
162 | | |
163 | | // Load all batches and count the rows |
164 | 2 | let (batches, _metrics, reservation1 ) = stream |
165 | 2 | .try_fold((Vec::new(), metrics, reservation), |mut acc, batch| async { |
166 | 2 | let batch_size = batch.get_array_memory_size(); |
167 | 2 | // Reserve memory for incoming batch |
168 | 2 | acc.2.try_grow(batch_size)?1 ; |
169 | | // Update metrics |
170 | 1 | acc.1.build_mem_used.add(batch_size); |
171 | 1 | acc.1.build_input_batches.add(1); |
172 | 1 | acc.1.build_input_rows.add(batch.num_rows()); |
173 | 1 | // Push batch to output |
174 | 1 | acc.0.push(batch); |
175 | 1 | Ok(acc) |
176 | 4 | })2 |
177 | 1 | .await0 ?; |
178 | | |
179 | 1 | let merged_batch = concat_batches(&left_schema, &batches)?0 ; |
180 | | |
181 | 1 | Ok((merged_batch, reservation)) |
182 | 2 | } |
183 | | |
184 | | impl DisplayAs for CrossJoinExec { |
185 | 0 | fn fmt_as( |
186 | 0 | &self, |
187 | 0 | t: DisplayFormatType, |
188 | 0 | f: &mut std::fmt::Formatter, |
189 | 0 | ) -> std::fmt::Result { |
190 | 0 | match t { |
191 | | DisplayFormatType::Default | DisplayFormatType::Verbose => { |
192 | 0 | write!(f, "CrossJoinExec") |
193 | 0 | } |
194 | 0 | } |
195 | 0 | } |
196 | | } |
197 | | |
198 | | impl ExecutionPlan for CrossJoinExec { |
199 | 0 | fn name(&self) -> &'static str { |
200 | 0 | "CrossJoinExec" |
201 | 0 | } |
202 | | |
203 | 0 | fn as_any(&self) -> &dyn Any { |
204 | 0 | self |
205 | 0 | } |
206 | | |
207 | 2 | fn properties(&self) -> &PlanProperties { |
208 | 2 | &self.cache |
209 | 2 | } |
210 | | |
211 | 0 | fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { |
212 | 0 | vec![&self.left, &self.right] |
213 | 0 | } |
214 | | |
215 | 0 | fn metrics(&self) -> Option<MetricsSet> { |
216 | 0 | Some(self.metrics.clone_inner()) |
217 | 0 | } |
218 | | |
219 | 0 | fn with_new_children( |
220 | 0 | self: Arc<Self>, |
221 | 0 | children: Vec<Arc<dyn ExecutionPlan>>, |
222 | 0 | ) -> Result<Arc<dyn ExecutionPlan>> { |
223 | 0 | Ok(Arc::new(CrossJoinExec::new( |
224 | 0 | Arc::clone(&children[0]), |
225 | 0 | Arc::clone(&children[1]), |
226 | 0 | ))) |
227 | 0 | } |
228 | | |
229 | 0 | fn required_input_distribution(&self) -> Vec<Distribution> { |
230 | 0 | vec![ |
231 | 0 | Distribution::SinglePartition, |
232 | 0 | Distribution::UnspecifiedDistribution, |
233 | 0 | ] |
234 | 0 | } |
235 | | |
236 | 2 | fn execute( |
237 | 2 | &self, |
238 | 2 | partition: usize, |
239 | 2 | context: Arc<TaskContext>, |
240 | 2 | ) -> Result<SendableRecordBatchStream> { |
241 | 2 | let stream = self.right.execute(partition, Arc::clone(&context))?0 ; |
242 | | |
243 | 2 | let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics); |
244 | 2 | |
245 | 2 | // Initialization of operator-level reservation |
246 | 2 | let reservation = |
247 | 2 | MemoryConsumer::new("CrossJoinExec").register(context.memory_pool()); |
248 | 2 | |
249 | 2 | let left_fut = self.left_fut.once(|| { |
250 | 2 | load_left_input( |
251 | 2 | Arc::clone(&self.left), |
252 | 2 | context, |
253 | 2 | join_metrics.clone(), |
254 | 2 | reservation, |
255 | 2 | ) |
256 | 2 | }); |
257 | 2 | |
258 | 2 | Ok(Box::pin(CrossJoinStream { |
259 | 2 | schema: Arc::clone(&self.schema), |
260 | 2 | left_fut, |
261 | 2 | right: stream, |
262 | 2 | left_index: 0, |
263 | 2 | join_metrics, |
264 | 2 | state: CrossJoinStreamState::WaitBuildSide, |
265 | 2 | left_data: RecordBatch::new_empty(self.left().schema()), |
266 | 2 | })) |
267 | 2 | } |
268 | | |
269 | 0 | fn statistics(&self) -> Result<Statistics> { |
270 | 0 | Ok(stats_cartesian_product( |
271 | 0 | self.left.statistics()?, |
272 | 0 | self.right.statistics()?, |
273 | | )) |
274 | 0 | } |
275 | | } |
276 | | |
277 | | /// [left/right]_col_count are required in case the column statistics are None |
278 | 2 | fn stats_cartesian_product( |
279 | 2 | left_stats: Statistics, |
280 | 2 | right_stats: Statistics, |
281 | 2 | ) -> Statistics { |
282 | 2 | let left_row_count = left_stats.num_rows; |
283 | 2 | let right_row_count = right_stats.num_rows; |
284 | 2 | |
285 | 2 | // calculate global stats |
286 | 2 | let num_rows = left_row_count.multiply(&right_row_count); |
287 | 2 | // the result size is two times a*b because you have the columns of both left and right |
288 | 2 | let total_byte_size = left_stats |
289 | 2 | .total_byte_size |
290 | 2 | .multiply(&right_stats.total_byte_size) |
291 | 2 | .multiply(&Precision::Exact(2)); |
292 | 2 | |
293 | 2 | let left_col_stats = left_stats.column_statistics; |
294 | 2 | let right_col_stats = right_stats.column_statistics; |
295 | 2 | |
296 | 2 | // the null counts must be multiplied by the row counts of the other side (if defined) |
297 | 2 | // Min, max and distinct_count on the other hand are invariants. |
298 | 2 | let cross_join_stats = left_col_stats |
299 | 2 | .into_iter() |
300 | 4 | .map(|s| ColumnStatistics { |
301 | 4 | null_count: s.null_count.multiply(&right_row_count), |
302 | 4 | distinct_count: s.distinct_count, |
303 | 4 | min_value: s.min_value, |
304 | 4 | max_value: s.max_value, |
305 | 4 | }) |
306 | 2 | .chain(right_col_stats.into_iter().map(|s| ColumnStatistics { |
307 | 2 | null_count: s.null_count.multiply(&left_row_count), |
308 | 2 | distinct_count: s.distinct_count, |
309 | 2 | min_value: s.min_value, |
310 | 2 | max_value: s.max_value, |
311 | 2 | })) |
312 | 2 | .collect(); |
313 | 2 | |
314 | 2 | Statistics { |
315 | 2 | num_rows, |
316 | 2 | total_byte_size, |
317 | 2 | column_statistics: cross_join_stats, |
318 | 2 | } |
319 | 2 | } |
320 | | |
321 | | /// A stream that issues [RecordBatch]es as they arrive from the right of the join. |
322 | | struct CrossJoinStream { |
323 | | /// Input schema |
324 | | schema: Arc<Schema>, |
325 | | /// Future for data from left side |
326 | | left_fut: OnceFut<JoinLeftData>, |
327 | | /// Right side stream |
328 | | right: SendableRecordBatchStream, |
329 | | /// Current value on the left |
330 | | left_index: usize, |
331 | | /// Join execution metrics |
332 | | join_metrics: BuildProbeJoinMetrics, |
333 | | /// State of the stream |
334 | | state: CrossJoinStreamState, |
335 | | /// Left data |
336 | | left_data: RecordBatch, |
337 | | } |
338 | | |
339 | | impl RecordBatchStream for CrossJoinStream { |
340 | 0 | fn schema(&self) -> SchemaRef { |
341 | 0 | Arc::clone(&self.schema) |
342 | 0 | } |
343 | | } |
344 | | |
345 | | /// Represents states of CrossJoinStream |
346 | | enum CrossJoinStreamState { |
347 | | WaitBuildSide, |
348 | | FetchProbeBatch, |
349 | | /// Holds the currently processed right side batch |
350 | | BuildBatches(RecordBatch), |
351 | | } |
352 | | |
353 | | impl CrossJoinStreamState { |
354 | | /// Tries to extract RecordBatch from CrossJoinStreamState enum. |
355 | | /// Returns an error if state is not BuildBatches state. |
356 | 4 | fn try_as_record_batch(&mut self) -> Result<&RecordBatch> { |
357 | 4 | match self { |
358 | 4 | CrossJoinStreamState::BuildBatches(rb) => Ok(rb), |
359 | 0 | _ => internal_err!("Expected RecordBatch in BuildBatches state"), |
360 | | } |
361 | 4 | } |
362 | | } |
363 | | |
364 | 3 | fn build_batch( |
365 | 3 | left_index: usize, |
366 | 3 | batch: &RecordBatch, |
367 | 3 | left_data: &RecordBatch, |
368 | 3 | schema: &Schema, |
369 | 3 | ) -> Result<RecordBatch> { |
370 | | // Repeat value on the left n times |
371 | 3 | let arrays = left_data |
372 | 3 | .columns() |
373 | 3 | .iter() |
374 | 9 | .map(|arr| { |
375 | 9 | let scalar = ScalarValue::try_from_array(arr, left_index)?0 ; |
376 | 9 | scalar.to_array_of_size(batch.num_rows()) |
377 | 9 | }) |
378 | 3 | .collect::<Result<Vec<_>>>()?0 ; |
379 | | |
380 | 3 | RecordBatch::try_new_with_options( |
381 | 3 | Arc::new(schema.clone()), |
382 | 3 | arrays |
383 | 3 | .iter() |
384 | 3 | .chain(batch.columns().iter()) |
385 | 3 | .cloned() |
386 | 3 | .collect(), |
387 | 3 | &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())), |
388 | 3 | ) |
389 | 3 | .map_err(Into::into) |
390 | 3 | } |
391 | | |
392 | | #[async_trait] |
393 | | impl Stream for CrossJoinStream { |
394 | | type Item = Result<RecordBatch>; |
395 | | |
396 | 5 | fn poll_next( |
397 | 5 | mut self: std::pin::Pin<&mut Self>, |
398 | 5 | cx: &mut std::task::Context<'_>, |
399 | 5 | ) -> std::task::Poll<Option<Self::Item>> { |
400 | 5 | self.poll_next_impl(cx) |
401 | 5 | } |
402 | | } |
403 | | |
404 | | impl CrossJoinStream { |
405 | | /// Separate implementation function that unpins the [`CrossJoinStream`] so |
406 | | /// that partial borrows work correctly |
407 | 5 | fn poll_next_impl( |
408 | 5 | &mut self, |
409 | 5 | cx: &mut std::task::Context<'_>, |
410 | 5 | ) -> std::task::Poll<Option<Result<RecordBatch>>> { |
411 | | loop { |
412 | 8 | return match self.state { |
413 | | CrossJoinStreamState::WaitBuildSide => { |
414 | 2 | handle_state!1 (ready!0 (self.collect_build_side(cx))) |
415 | | } |
416 | | CrossJoinStreamState::FetchProbeBatch => { |
417 | 2 | handle_state!0 (ready!0 (self.fetch_probe_batch(cx))) |
418 | | } |
419 | | CrossJoinStreamState::BuildBatches(_) => { |
420 | 4 | handle_state!0 (self.build_batches()) |
421 | | } |
422 | | }; |
423 | | } |
424 | 5 | } |
425 | | |
426 | | /// Collects build (left) side of the join into the state. In case of an empty build batch, |
427 | | /// the execution terminates. Otherwise, the state is updated to fetch probe (right) batch. |
428 | 2 | fn collect_build_side( |
429 | 2 | &mut self, |
430 | 2 | cx: &mut std::task::Context<'_>, |
431 | 2 | ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> { |
432 | 2 | let build_timer = self.join_metrics.build_time.timer(); |
433 | 2 | let (left_data1 , _) = match ready!0 (self.left_fut.get(cx)) { |
434 | 1 | Ok(left_data) => left_data, |
435 | 1 | Err(e) => return Poll::Ready(Err(e)), |
436 | | }; |
437 | 1 | build_timer.done(); |
438 | | |
439 | 1 | let result = if left_data.num_rows() == 0 { |
440 | 0 | StatefulStreamResult::Ready(None) |
441 | | } else { |
442 | 1 | self.left_data = left_data.clone(); |
443 | 1 | self.state = CrossJoinStreamState::FetchProbeBatch; |
444 | 1 | StatefulStreamResult::Continue |
445 | | }; |
446 | 1 | Poll::Ready(Ok(result)) |
447 | 2 | } |
448 | | |
449 | | /// Fetches the probe (right) batch, updates the metrics, and save the batch in the state. |
450 | | /// Then, the state is updated to build result batches. |
451 | 2 | fn fetch_probe_batch( |
452 | 2 | &mut self, |
453 | 2 | cx: &mut std::task::Context<'_>, |
454 | 2 | ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> { |
455 | 2 | self.left_index = 0; |
456 | 2 | let right_data1 = match ready!0 (self.right.poll_next_unpin(cx)) { |
457 | 1 | Some(Ok(right_data)) => right_data, |
458 | 0 | Some(Err(e)) => return Poll::Ready(Err(e)), |
459 | 1 | None => return Poll::Ready(Ok(StatefulStreamResult::Ready(None))), |
460 | | }; |
461 | 1 | self.join_metrics.input_batches.add(1); |
462 | 1 | self.join_metrics.input_rows.add(right_data.num_rows()); |
463 | 1 | |
464 | 1 | self.state = CrossJoinStreamState::BuildBatches(right_data); |
465 | 1 | Poll::Ready(Ok(StatefulStreamResult::Continue)) |
466 | 2 | } |
467 | | |
468 | | /// Joins the the indexed row of left data with the current probe batch. |
469 | | /// If all the results are produced, the state is set to fetch new probe batch. |
470 | 4 | fn build_batches(&mut self) -> Result<StatefulStreamResult<Option<RecordBatch>>> { |
471 | 4 | let right_batch = self.state.try_as_record_batch()?0 ; |
472 | 4 | if self.left_index < self.left_data.num_rows() { |
473 | 3 | let join_timer = self.join_metrics.join_time.timer(); |
474 | 3 | let result = |
475 | 3 | build_batch(self.left_index, right_batch, &self.left_data, &self.schema); |
476 | 3 | join_timer.done(); |
477 | | |
478 | 3 | if let Ok(ref batch) = result { |
479 | 3 | self.join_metrics.output_batches.add(1); |
480 | 3 | self.join_metrics.output_rows.add(batch.num_rows()); |
481 | 3 | }0 |
482 | 3 | self.left_index += 1; |
483 | 3 | result.map(|r| StatefulStreamResult::Ready(Some(r))) |
484 | | } else { |
485 | 1 | self.state = CrossJoinStreamState::FetchProbeBatch; |
486 | 1 | Ok(StatefulStreamResult::Continue) |
487 | | } |
488 | 4 | } |
489 | | } |
490 | | |
491 | | #[cfg(test)] |
492 | | mod tests { |
493 | | use super::*; |
494 | | use crate::common; |
495 | | use crate::test::build_table_scan_i32; |
496 | | |
497 | | use datafusion_common::{assert_batches_sorted_eq, assert_contains}; |
498 | | use datafusion_execution::runtime_env::RuntimeEnvBuilder; |
499 | | |
500 | 2 | async fn join_collect( |
501 | 2 | left: Arc<dyn ExecutionPlan>, |
502 | 2 | right: Arc<dyn ExecutionPlan>, |
503 | 2 | context: Arc<TaskContext>, |
504 | 2 | ) -> Result<(Vec<String>, Vec<RecordBatch>)> { |
505 | 2 | let join = CrossJoinExec::new(left, right); |
506 | 2 | let columns_header = columns(&join.schema()); |
507 | | |
508 | 2 | let stream = join.execute(0, context)?0 ; |
509 | 2 | let batches1 = common::collect(stream).await0 ?1 ; |
510 | | |
511 | 1 | Ok((columns_header, batches)) |
512 | 2 | } |
513 | | |
514 | | #[tokio::test] |
515 | 1 | async fn test_stats_cartesian_product() { |
516 | 1 | let left_row_count = 11; |
517 | 1 | let left_bytes = 23; |
518 | 1 | let right_row_count = 7; |
519 | 1 | let right_bytes = 27; |
520 | 1 | |
521 | 1 | let left = Statistics { |
522 | 1 | num_rows: Precision::Exact(left_row_count), |
523 | 1 | total_byte_size: Precision::Exact(left_bytes), |
524 | 1 | column_statistics: vec![ |
525 | 1 | ColumnStatistics { |
526 | 1 | distinct_count: Precision::Exact(5), |
527 | 1 | max_value: Precision::Exact(ScalarValue::Int64(Some(21))), |
528 | 1 | min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), |
529 | 1 | null_count: Precision::Exact(0), |
530 | 1 | }, |
531 | 1 | ColumnStatistics { |
532 | 1 | distinct_count: Precision::Exact(1), |
533 | 1 | max_value: Precision::Exact(ScalarValue::from("x")), |
534 | 1 | min_value: Precision::Exact(ScalarValue::from("a")), |
535 | 1 | null_count: Precision::Exact(3), |
536 | 1 | }, |
537 | 1 | ], |
538 | 1 | }; |
539 | 1 | |
540 | 1 | let right = Statistics { |
541 | 1 | num_rows: Precision::Exact(right_row_count), |
542 | 1 | total_byte_size: Precision::Exact(right_bytes), |
543 | 1 | column_statistics: vec![ColumnStatistics { |
544 | 1 | distinct_count: Precision::Exact(3), |
545 | 1 | max_value: Precision::Exact(ScalarValue::Int64(Some(12))), |
546 | 1 | min_value: Precision::Exact(ScalarValue::Int64(Some(0))), |
547 | 1 | null_count: Precision::Exact(2), |
548 | 1 | }], |
549 | 1 | }; |
550 | 1 | |
551 | 1 | let result = stats_cartesian_product(left, right); |
552 | 1 | |
553 | 1 | let expected = Statistics { |
554 | 1 | num_rows: Precision::Exact(left_row_count * right_row_count), |
555 | 1 | total_byte_size: Precision::Exact(2 * left_bytes * right_bytes), |
556 | 1 | column_statistics: vec![ |
557 | 1 | ColumnStatistics { |
558 | 1 | distinct_count: Precision::Exact(5), |
559 | 1 | max_value: Precision::Exact(ScalarValue::Int64(Some(21))), |
560 | 1 | min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), |
561 | 1 | null_count: Precision::Exact(0), |
562 | 1 | }, |
563 | 1 | ColumnStatistics { |
564 | 1 | distinct_count: Precision::Exact(1), |
565 | 1 | max_value: Precision::Exact(ScalarValue::from("x")), |
566 | 1 | min_value: Precision::Exact(ScalarValue::from("a")), |
567 | 1 | null_count: Precision::Exact(3 * right_row_count), |
568 | 1 | }, |
569 | 1 | ColumnStatistics { |
570 | 1 | distinct_count: Precision::Exact(3), |
571 | 1 | max_value: Precision::Exact(ScalarValue::Int64(Some(12))), |
572 | 1 | min_value: Precision::Exact(ScalarValue::Int64(Some(0))), |
573 | 1 | null_count: Precision::Exact(2 * left_row_count), |
574 | 1 | }, |
575 | 1 | ], |
576 | 1 | }; |
577 | 1 | |
578 | 1 | assert_eq!(result, expected); |
579 | 1 | } |
580 | | |
581 | | #[tokio::test] |
582 | 1 | async fn test_stats_cartesian_product_with_unknown_size() { |
583 | 1 | let left_row_count = 11; |
584 | 1 | |
585 | 1 | let left = Statistics { |
586 | 1 | num_rows: Precision::Exact(left_row_count), |
587 | 1 | total_byte_size: Precision::Exact(23), |
588 | 1 | column_statistics: vec![ |
589 | 1 | ColumnStatistics { |
590 | 1 | distinct_count: Precision::Exact(5), |
591 | 1 | max_value: Precision::Exact(ScalarValue::Int64(Some(21))), |
592 | 1 | min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), |
593 | 1 | null_count: Precision::Exact(0), |
594 | 1 | }, |
595 | 1 | ColumnStatistics { |
596 | 1 | distinct_count: Precision::Exact(1), |
597 | 1 | max_value: Precision::Exact(ScalarValue::from("x")), |
598 | 1 | min_value: Precision::Exact(ScalarValue::from("a")), |
599 | 1 | null_count: Precision::Exact(3), |
600 | 1 | }, |
601 | 1 | ], |
602 | 1 | }; |
603 | 1 | |
604 | 1 | let right = Statistics { |
605 | 1 | num_rows: Precision::Absent, |
606 | 1 | total_byte_size: Precision::Absent, |
607 | 1 | column_statistics: vec![ColumnStatistics { |
608 | 1 | distinct_count: Precision::Exact(3), |
609 | 1 | max_value: Precision::Exact(ScalarValue::Int64(Some(12))), |
610 | 1 | min_value: Precision::Exact(ScalarValue::Int64(Some(0))), |
611 | 1 | null_count: Precision::Exact(2), |
612 | 1 | }], |
613 | 1 | }; |
614 | 1 | |
615 | 1 | let result = stats_cartesian_product(left, right); |
616 | 1 | |
617 | 1 | let expected = Statistics { |
618 | 1 | num_rows: Precision::Absent, |
619 | 1 | total_byte_size: Precision::Absent, |
620 | 1 | column_statistics: vec![ |
621 | 1 | ColumnStatistics { |
622 | 1 | distinct_count: Precision::Exact(5), |
623 | 1 | max_value: Precision::Exact(ScalarValue::Int64(Some(21))), |
624 | 1 | min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), |
625 | 1 | null_count: Precision::Absent, // we don't know the row count on the right |
626 | 1 | }, |
627 | 1 | ColumnStatistics { |
628 | 1 | distinct_count: Precision::Exact(1), |
629 | 1 | max_value: Precision::Exact(ScalarValue::from("x")), |
630 | 1 | min_value: Precision::Exact(ScalarValue::from("a")), |
631 | 1 | null_count: Precision::Absent, // we don't know the row count on the right |
632 | 1 | }, |
633 | 1 | ColumnStatistics { |
634 | 1 | distinct_count: Precision::Exact(3), |
635 | 1 | max_value: Precision::Exact(ScalarValue::Int64(Some(12))), |
636 | 1 | min_value: Precision::Exact(ScalarValue::Int64(Some(0))), |
637 | 1 | null_count: Precision::Exact(2 * left_row_count), |
638 | 1 | }, |
639 | 1 | ], |
640 | 1 | }; |
641 | 1 | |
642 | 1 | assert_eq!(result, expected); |
643 | 1 | } |
644 | | |
645 | | #[tokio::test] |
646 | 1 | async fn test_join() -> Result<()> { |
647 | 1 | let task_ctx = Arc::new(TaskContext::default()); |
648 | 1 | |
649 | 1 | let left = build_table_scan_i32( |
650 | 1 | ("a1", &vec![1, 2, 3]), |
651 | 1 | ("b1", &vec![4, 5, 6]), |
652 | 1 | ("c1", &vec![7, 8, 9]), |
653 | 1 | ); |
654 | 1 | let right = build_table_scan_i32( |
655 | 1 | ("a2", &vec![10, 11]), |
656 | 1 | ("b2", &vec![12, 13]), |
657 | 1 | ("c2", &vec![14, 15]), |
658 | 1 | ); |
659 | 1 | |
660 | 1 | let (columns, batches) = join_collect(left, right, task_ctx).await0 ?0 ; |
661 | 1 | |
662 | 1 | assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); |
663 | 1 | let expected = [ |
664 | 1 | "+----+----+----+----+----+----+", |
665 | 1 | "| a1 | b1 | c1 | a2 | b2 | c2 |", |
666 | 1 | "+----+----+----+----+----+----+", |
667 | 1 | "| 1 | 4 | 7 | 10 | 12 | 14 |", |
668 | 1 | "| 1 | 4 | 7 | 11 | 13 | 15 |", |
669 | 1 | "| 2 | 5 | 8 | 10 | 12 | 14 |", |
670 | 1 | "| 2 | 5 | 8 | 11 | 13 | 15 |", |
671 | 1 | "| 3 | 6 | 9 | 10 | 12 | 14 |", |
672 | 1 | "| 3 | 6 | 9 | 11 | 13 | 15 |", |
673 | 1 | "+----+----+----+----+----+----+", |
674 | 1 | ]; |
675 | 1 | |
676 | 1 | assert_batches_sorted_eq!(expected, &batches); |
677 | 1 | |
678 | 1 | Ok(()) |
679 | 1 | } |
680 | | |
681 | | #[tokio::test] |
682 | 1 | async fn test_overallocation() -> Result<()> { |
683 | 1 | let runtime = RuntimeEnvBuilder::new() |
684 | 1 | .with_memory_limit(100, 1.0) |
685 | 1 | .build_arc()?0 ; |
686 | 1 | let task_ctx = TaskContext::default().with_runtime(runtime); |
687 | 1 | let task_ctx = Arc::new(task_ctx); |
688 | 1 | |
689 | 1 | let left = build_table_scan_i32( |
690 | 1 | ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), |
691 | 1 | ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), |
692 | 1 | ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), |
693 | 1 | ); |
694 | 1 | let right = build_table_scan_i32( |
695 | 1 | ("a2", &vec![10, 11]), |
696 | 1 | ("b2", &vec![12, 13]), |
697 | 1 | ("c2", &vec![14, 15]), |
698 | 1 | ); |
699 | 1 | |
700 | 1 | let err = join_collect(left, right, task_ctx).await0 .unwrap_err(); |
701 | 1 | |
702 | 1 | assert_contains!( |
703 | 1 | err.to_string(), |
704 | 1 | "External error: Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: CrossJoinExec" |
705 | 1 | ); |
706 | 1 | |
707 | 1 | Ok(()) |
708 | 1 | } |
709 | | |
710 | | /// Returns the column names on the schema |
711 | 2 | fn columns(schema: &Schema) -> Vec<String> { |
712 | 12 | schema.fields().iter().map(|f| f.name().clone()).collect() |
713 | 2 | } |
714 | | } |