/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/joins/stream_join_utils.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! This file contains common subroutines for symmetric hash join
//! related functionality, used both in join calculations and optimization rules.

use std::collections::{HashMap, VecDeque};
use std::sync::Arc;

use crate::joins::utils::{JoinFilter, JoinHashMapType};
use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder};
use crate::{metrics, ExecutionPlan};

use arrow::compute::concat_batches;
use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, RecordBatch};
use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder};
use arrow_schema::{Schema, SchemaRef};
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{
    arrow_datafusion_err, plan_datafusion_err, DataFusionError, JoinSide, Result,
    ScalarValue,
};
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};

use hashbrown::raw::RawTable;
use hashbrown::HashSet;

/// Implementation of `JoinHashMapType` for `PruningJoinHashMap`.
impl JoinHashMapType for PruningJoinHashMap {
    type NextType = VecDeque<u64>;

    /// Extend the `next` list with zeros.
    fn extend_zero(&mut self, len: usize) {
        self.next.resize(self.next.len() + len, 0)
    }

    /// Get mutable references to the hash map and the `next` list.
    fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType) {
        (&mut self.map, &mut self.next)
    }

    /// Get a reference to the hash map.
    fn get_map(&self) -> &RawTable<(u64, u64)> {
        &self.map
    }

    /// Get a reference to the `next` list.
    fn get_list(&self) -> &Self::NextType {
        &self.next
    }
}

/// The `PruningJoinHashMap` is similar to a regular `JoinHashMap`, but with
/// the capability of pruning elements in an efficient manner. This structure
/// is particularly useful for cases where it's necessary to remove elements
/// from the map based on their buffer order.
///
/// # Example
///
/// ``` text
/// Let's continue the example of `JoinHashMap` and then show how `PruningJoinHashMap` would
/// handle the pruning scenario.
///
/// Insert the pair (10,4) into the `PruningJoinHashMap`:
/// map:
/// ----------
/// | 10 | 5 |
/// | 20 | 3 |
/// ----------
/// list:
/// ---------------------
/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 10 maps to 5,4,2 (which means indices values 4,3,1)
/// ---------------------
///
/// Now, let's prune 3 rows from `PruningJoinHashMap`:
/// map:
/// ----------
/// | 10 | 5 |
/// ----------
/// list:
/// ---------
/// | 2 | 4 | <--- hash value 10 maps to 2 (5 - 3), 1 (4 - 3), NA (2 - 3) (which means indices values 1,0)
/// ---------
///
/// After pruning, the | 20 | 3 | entry is deleted from `PruningJoinHashMap` since
/// there are no values left for this key.
/// ```
pub struct PruningJoinHashMap {
    /// Stores hash value to last row index
    pub map: RawTable<(u64, u64)>,
    /// Stores indices in chained list data structure
    pub next: VecDeque<u64>,
}

impl PruningJoinHashMap {
    /// Constructs a new `PruningJoinHashMap` with the given capacity.
    /// Both the map and the list are pre-allocated with the provided capacity.
    ///
    /// # Arguments
    /// * `capacity`: The initial capacity of the hash map.
    ///
    /// # Returns
    /// A new instance of `PruningJoinHashMap`.
    pub(crate) fn with_capacity(capacity: usize) -> Self {
        PruningJoinHashMap {
            map: RawTable::with_capacity(capacity),
            next: VecDeque::with_capacity(capacity),
        }
    }

    /// Shrinks the capacity of the hash map, if necessary, based on the
    /// provided scale factor.
    ///
    /// # Arguments
    /// * `scale_factor`: The scale factor that determines how conservative the
    ///   shrinking strategy is. The capacity will be reduced by 1/`scale_factor`
    ///   when necessary.
    ///
    /// # Note
    /// Increasing the scale factor results in less aggressive capacity shrinking,
    /// leading to potentially higher memory usage but fewer resizes. Conversely,
    /// decreasing the scale factor results in more aggressive capacity shrinking,
    /// potentially leading to lower memory usage but more frequent resizing.
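    ///
    /// As a worked example of the arithmetic below: with `scale_factor = 4`, a
    /// map whose allocation holds 1000 slots but only 100 entries satisfies
    /// `1000 > 4 * 100`, so its capacity is reduced to `1000 * (4 - 1) / 4 = 750`
    /// slots.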
    pub(crate) fn shrink_if_necessary(&mut self, scale_factor: usize) {
        let capacity = self.map.capacity();

        if capacity > scale_factor * self.map.len() {
            let new_capacity = (capacity * (scale_factor - 1)) / scale_factor;
            // Resize the map with the new capacity.
            self.map.shrink_to(new_capacity, |(hash, _)| *hash)
        }
    }

    /// Calculates the size of the `PruningJoinHashMap` in bytes.
    ///
    /// # Returns
    /// The size of the hash map in bytes.
    pub(crate) fn size(&self) -> usize {
        self.map.allocation_info().1.size()
            + self.next.capacity() * std::mem::size_of::<u64>()
    }

    /// Removes hash values from the map and the list based on the given pruning
    /// length and deleting offset.
    ///
    /// # Arguments
    /// * `prune_length`: The number of elements to remove from the list.
    /// * `deleting_offset`: The offset used to determine which hash values to remove from the map.
    /// * `shrink_factor`: The scale factor passed to [`Self::shrink_if_necessary`] after pruning.
    pub(crate) fn prune_hash_values(
        &mut self,
        prune_length: usize,
        deleting_offset: u64,
        shrink_factor: usize,
    ) {
        // Remove elements from the list based on the pruning length.
        self.next.drain(0..prune_length);

        // Calculate the keys that should be removed from the map.
        let removable_keys = unsafe {
            self.map
                .iter()
                .map(|bucket| bucket.as_ref())
                .filter_map(|(hash, tail_index)| {
                    (*tail_index < prune_length as u64 + deleting_offset).then_some(*hash)
                })
                .collect::<Vec<_>>()
        };

        // Remove the keys from the map.
        removable_keys.into_iter().for_each(|hash_value| {
            self.map
                .remove_entry(hash_value, |(hash, _)| hash_value == *hash);
        });

        // Shrink the map if necessary.
        self.shrink_if_necessary(shrink_factor);
    }
}
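
// A minimal sketch (added for illustration; not upstream test coverage) of the
// pruning flow documented above. It must live inside this crate because
// `prune_hash_values` is `pub(crate)`; the chain pointer values are illustrative.
#[cfg(test)]
mod pruning_map_sketch {
    use super::*;

    #[test]
    fn pruning_drops_exhausted_keys() {
        let mut map = PruningJoinHashMap::with_capacity(8);
        // Three buffered rows; chain pointers are illustrative only.
        map.next.extend([0u64, 0, 0]);
        // Hash 10's chain ends at row 0 (stored as 0 + 1 = 1):
        map.map.insert(10, (10, 1), |(h, _)| *h);
        // Hash 20's chain ends at row 2 (stored as 2 + 1 = 3):
        map.map.insert(20, (20, 3), |(h, _)| *h);
        // Prune the first two rows: hash 10's chain is fully consumed and its
        // key is removed from the map, while hash 20 survives.
        map.prune_hash_values(2, 0, 4);
        assert_eq!(map.map.len(), 1);
        assert_eq!(map.next.len(), 1);
    }
}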

fn check_filter_expr_contains_sort_information(
    expr: &Arc<dyn PhysicalExpr>,
    reference: &Arc<dyn PhysicalExpr>,
) -> bool {
    expr.eq(reference)
        || expr
            .children()
            .iter()
            .any(|e| check_filter_expr_contains_sort_information(e, reference))
}

/// Create a one-to-one mapping from main columns to filter columns using
/// filter column indices. A column index looks like:
/// ```text
/// ColumnIndex {
///     index: 0, // field index in main schema
///     side: JoinSide::Left, // child side
/// }
/// ```
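///
/// For example (derived from the loop below), if the main (left child) schema
/// is `[a, b]` and the filter's column indices are `[{index: 1, side: Left}]`,
/// then main column `b` maps to column `0` of the filter's intermediate schema.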
pub fn map_origin_col_to_filter_col(
    filter: &JoinFilter,
    schema: &SchemaRef,
    side: &JoinSide,
) -> Result<HashMap<Column, Column>> {
    let filter_schema = filter.schema();
    let mut col_to_col_map = HashMap::<Column, Column>::new();
    for (filter_schema_index, index) in filter.column_indices().iter().enumerate() {
        if index.side.eq(side) {
            // Get the main field from column index:
            let main_field = schema.field(index.index);
            // Create a column expression:
            let main_col = Column::new_with_schema(main_field.name(), schema.as_ref())?;
            // Since the order of `filter.column_indices()` is the same as that
            // of the intermediate schema fields, we can get the column directly.
            let filter_field = filter_schema.field(filter_schema_index);
            let filter_col = Column::new(filter_field.name(), filter_schema_index);
            // Insert mapping:
            col_to_col_map.insert(main_col, filter_col);
        }
    }
    Ok(col_to_col_map)
}

/// This function analyzes [`PhysicalSortExpr`] graphs with respect to output orderings
/// (sorting) properties. This is necessary since monotonically increasing and/or
/// decreasing expressions are required when using join filter expressions for
/// data pruning purposes.
///
/// The method works as follows:
/// 1. Maps the original columns to the filter columns using the [`map_origin_col_to_filter_col`] function.
/// 2. Collects all columns in the sort expression using the [`collect_columns`] function.
/// 3. Checks if all columns are included in the map we obtain in the first step.
/// 4. If all columns are included, the sort expression is converted into a filter expression using
///    the [`convert_filter_columns`] function.
/// 5. Searches for the converted filter expression in the filter expression using the
///    [`check_filter_expr_contains_sort_information`] function.
/// 6. If an exact match is found, returns the converted filter expression as [`Some(Arc<dyn PhysicalExpr>)`].
/// 7. If all columns are not included or an exact match is not found, returns [`None`].
///
/// Examples:
/// Consider the filter expression "a + b > c + 10 AND a + b < c + 100".
/// 1. If the expression "a@ + d@" is sorted, it will not be accepted since the "d@" column is not part of the filter.
/// 2. If the expression "d@" is sorted, it will not be accepted since the "d@" column is not part of the filter.
/// 3. If the expression "a@ + b@ + c@" is sorted, all columns are represented in the filter expression. However,
///    there is no exact match, so this expression does not indicate pruning.
pub fn convert_sort_expr_with_filter_schema(
    side: &JoinSide,
    filter: &JoinFilter,
    schema: &SchemaRef,
    sort_expr: &PhysicalSortExpr,
) -> Result<Option<Arc<dyn PhysicalExpr>>> {
    let column_map = map_origin_col_to_filter_col(filter, schema, side)?;
    let expr = Arc::clone(&sort_expr.expr);
    // Get main schema columns:
    let expr_columns = collect_columns(&expr);
    // Calculation is possible with `column_map` since sort exprs belong to a child.
    let all_columns_are_included =
        expr_columns.iter().all(|col| column_map.contains_key(col));
    if all_columns_are_included {
        // Since we are sure that the one-to-one column mapping includes all columns,
        // we convert the sort expression into a filter expression.
        let converted_filter_expr = expr
            .transform_up(|p| {
                convert_filter_columns(p.as_ref(), &column_map).map(|transformed| {
                    match transformed {
                        Some(transformed) => Transformed::yes(transformed),
                        None => Transformed::no(p),
                    }
                })
            })
            .data()?;
        // Search the converted `PhysicalExpr` in filter expression; if an exact
        // match is found, use this sorted expression in graph traversals.
        if check_filter_expr_contains_sort_information(
            filter.expression(),
            &converted_filter_expr,
        ) {
            return Ok(Some(converted_filter_expr));
        }
    }
    Ok(None)
}

/// This function is used to build the filter expression based on the sort order of input columns.
///
/// It first calls the [`convert_sort_expr_with_filter_schema`] method to determine if the sort
/// order of columns can be used in the filter expression. If it returns a [`Some`] value, the
/// method wraps the result in a [`SortedFilterExpr`] instance with the original sort expression and
/// the converted filter expression. Otherwise, this function returns `Ok(None)`.
///
/// The `SortedFilterExpr` instance contains information about the sort order of columns that can
/// be used in the filter expression, which can be used to optimize the query execution process.
pub fn build_filter_input_order(
    side: JoinSide,
    filter: &JoinFilter,
    schema: &SchemaRef,
    order: &PhysicalSortExpr,
) -> Result<Option<SortedFilterExpr>> {
    let opt_expr = convert_sort_expr_with_filter_schema(&side, filter, schema, order)?;
    opt_expr
        .map(|filter_expr| {
            SortedFilterExpr::try_new(order.clone(), filter_expr, filter.schema())
        })
        .transpose()
}

/// Convert a physical expression into a filter expression using the given
/// column mapping information.
fn convert_filter_columns(
    input: &dyn PhysicalExpr,
    column_map: &HashMap<Column, Column>,
) -> Result<Option<Arc<dyn PhysicalExpr>>> {
    // Attempt to downcast the input expression to a Column type.
    Ok(if let Some(col) = input.as_any().downcast_ref::<Column>() {
        // If the downcast is successful, retrieve the corresponding filter column.
        column_map.get(col).map(|c| Arc::new(c.clone()) as _)
    } else {
        // If the downcast fails, return the input expression as is.
        None
    })
}

/// The [SortedFilterExpr] object represents a sorted filter expression. It
/// contains the following information: The origin expression, the filter
/// expression, an interval encapsulating expression bounds, and a stable
/// index identifying the expression in the expression DAG.
///
/// Physical schema of a [JoinFilter]'s intermediate batch combines two sides
/// and uses new column names. In this process, a column exchange is done so
/// we can utilize sorting information while traversing the filter expression
/// DAG for interval calculations. When evaluating the inner buffer, we use
/// `origin_sorted_expr`.
#[derive(Debug, Clone)]
pub struct SortedFilterExpr {
    /// Sorted expression from a join side (i.e. a child of the join)
    origin_sorted_expr: PhysicalSortExpr,
    /// Expression adjusted for filter schema.
    filter_expr: Arc<dyn PhysicalExpr>,
    /// Interval containing expression bounds
    interval: Interval,
    /// Node index in the expression DAG
    node_index: usize,
}

impl SortedFilterExpr {
    /// Constructor
    pub fn try_new(
        origin_sorted_expr: PhysicalSortExpr,
        filter_expr: Arc<dyn PhysicalExpr>,
        filter_schema: &Schema,
    ) -> Result<Self> {
        let dt = &filter_expr.data_type(filter_schema)?;
        Ok(Self {
            origin_sorted_expr,
            filter_expr,
            interval: Interval::make_unbounded(dt)?,
            node_index: 0,
        })
    }
    /// Get origin expr information
    pub fn origin_sorted_expr(&self) -> &PhysicalSortExpr {
        &self.origin_sorted_expr
    }
    /// Get filter expr information
    pub fn filter_expr(&self) -> &Arc<dyn PhysicalExpr> {
        &self.filter_expr
    }
    /// Get interval information
    pub fn interval(&self) -> &Interval {
        &self.interval
    }
    /// Sets interval
    pub fn set_interval(&mut self, interval: Interval) {
        self.interval = interval;
    }
    /// Node index in ExprIntervalGraph
    pub fn node_index(&self) -> usize {
        self.node_index
    }
    /// Node index setter in ExprIntervalGraph
    pub fn set_node_index(&mut self, node_index: usize) {
        self.node_index = node_index;
    }
}

/// Calculate the filter expression intervals.
///
/// This function updates the `interval` field of each `SortedFilterExpr` based
/// on the first or the last value of the expression in `build_input_buffer`
/// and `probe_batch`.
///
/// # Arguments
///
/// * `build_input_buffer` - The [RecordBatch] on the build side of the join.
/// * `build_sorted_filter_expr` - Build side [SortedFilterExpr] to update.
/// * `probe_batch` - The `RecordBatch` on the probe side of the join.
/// * `probe_sorted_filter_expr` - Probe side `SortedFilterExpr` to update.
///
/// ### Note
/// ```text
///
/// Interval arithmetic is used to calculate viable join ranges for build-side
/// pruning. This is done by first creating an interval for join filter values in
/// the build side of the join, which spans [-∞, FV] or [FV, ∞] depending on the
/// ordering (descending/ascending) of the filter expression. Here, FV denotes the
/// first value on the build side. This range is then compared with the probe side
/// interval, which either spans [-∞, LV] or [LV, ∞] depending on the ordering
/// (ascending/descending) of the probe side. Here, LV denotes the last value on
/// the probe side.
///
/// As a concrete example, consider the following query:
///
///   SELECT * FROM left_table, right_table
///   WHERE
///     left_key = right_key AND
///     a > b - 3 AND
///     a < b + 10
///
/// where columns "a" and "b" come from tables "left_table" and "right_table",
/// respectively. When a new `RecordBatch` arrives at the right side, the
/// condition a > b - 3 will possibly indicate a prunable range for the left
/// side. Conversely, when a new `RecordBatch` arrives at the left side, the
/// condition a < b + 10 will possibly indicate prunability for the right side.
/// Let's inspect what happens when a new `RecordBatch` arrives at the right
/// side (i.e. when the left side is the build side):
///
///         Build      Probe
///       +-------+  +-------+
///       | a | z |  | b | y |
///       |+--|--+|  |+--|--+|
///       | 1 | 2 |  | 4 | 3 |
///       |+--|--+|  |+--|--+|
///       | 3 | 1 |  | 4 | 3 |
///       |+--|--+|  |+--|--+|
///       | 5 | 7 |  | 6 | 1 |
///       |+--|--+|  |+--|--+|
///       | 7 | 1 |  | 6 | 3 |
///       +-------+  +-------+
///
/// In this case, the interval representing viable (i.e. joinable) values for
/// column "a" is [1, ∞], and the interval representing possible future values
/// for column "b" is [6, ∞]. With these intervals at hand, we next calculate
/// intervals for the whole filter expression and propagate join constraint by
/// traversing the expression graph.
/// ```
pub fn calculate_filter_expr_intervals(
    build_input_buffer: &RecordBatch,
    build_sorted_filter_expr: &mut SortedFilterExpr,
    probe_batch: &RecordBatch,
    probe_sorted_filter_expr: &mut SortedFilterExpr,
) -> Result<()> {
    // If either build or probe side has no data, return early:
    if build_input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 {
        return Ok(());
    }
    // Calculate the interval for the build side filter expression (if present):
    update_filter_expr_interval(
        &build_input_buffer.slice(0, 1),
        build_sorted_filter_expr,
    )?;
    // Calculate the interval for the probe side filter expression (if present):
    update_filter_expr_interval(
        &probe_batch.slice(probe_batch.num_rows() - 1, 1),
        probe_sorted_filter_expr,
    )
}

/// This is a subroutine of the function [`calculate_filter_expr_intervals`].
/// It constructs the current interval using the given `batch` and updates
/// the filter expression (i.e. `sorted_expr`) with this interval.
pub fn update_filter_expr_interval(
    batch: &RecordBatch,
    sorted_expr: &mut SortedFilterExpr,
) -> Result<()> {
    // Evaluate the filter expression and convert the result to an array:
    let array = sorted_expr
        .origin_sorted_expr()
        .expr
        .evaluate(batch)?
        .into_array(1)?;
    // Convert the array to a ScalarValue:
    let value = ScalarValue::try_from_array(&array, 0)?;
    // Create a ScalarValue representing positive or negative infinity for the same data type:
    let inf = ScalarValue::try_from(value.data_type())?;
    // Update the interval with lower and upper bounds based on the sort option:
    let interval = if sorted_expr.origin_sorted_expr().options.descending {
        Interval::try_new(inf, value)?
    } else {
        Interval::try_new(value, inf)?
    };
    // Set the calculated interval for the sorted filter expression:
    sorted_expr.set_interval(interval);
    Ok(())
}
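
// A minimal sketch (added for illustration) of the interval construction used
// by `update_filter_expr_interval` above: a null `ScalarValue` of the matching
// data type stands in for an unbounded endpoint. Assumes `Interval::lower`/
// `Interval::upper` accessors from `datafusion_expr::interval_arithmetic`.
#[cfg(test)]
mod interval_construction_sketch {
    use super::*;

    #[test]
    fn ascending_first_value_yields_lower_bounded_interval() -> Result<()> {
        let value = ScalarValue::Int32(Some(3));
        // A null scalar of the same type represents ±∞:
        let inf = ScalarValue::try_from(value.data_type())?;
        // For an ascending sort expression with first value 3, the viable
        // range is [3, ∞]:
        let interval = Interval::try_new(value, inf)?;
        assert_eq!(interval.lower(), &ScalarValue::Int32(Some(3)));
        assert!(interval.upper().is_null());
        Ok(())
    }
}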

/// Get the anti join indices from the visited hash set.
///
/// This method returns the indices from the original input that were not present in the visited hash set.
///
/// # Arguments
///
/// * `prune_length` - The length of the pruned record batch.
/// * `deleted_offset` - The offset to the indices.
/// * `visited_rows` - The hash set of visited indices.
///
/// # Returns
///
/// A `PrimitiveArray` of the anti join indices.
pub fn get_pruning_anti_indices<T: ArrowPrimitiveType>(
    prune_length: usize,
    deleted_offset: usize,
    visited_rows: &HashSet<usize>,
) -> PrimitiveArray<T>
where
    NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
{
    let mut bitmap = BooleanBufferBuilder::new(prune_length);
    bitmap.append_n(prune_length, false);
    // Mark the indices as true if they are present in the visited hash set:
    for v in 0..prune_length {
        let row = v + deleted_offset;
        bitmap.set_bit(v, visited_rows.contains(&row));
    }
    // Get the anti index:
    (0..prune_length)
        .filter_map(|idx| (!bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx)))
        .collect()
}

/// This method creates a boolean buffer from the visited rows hash set
/// and the indices of the pruned record batch slice.
///
/// It gets the indices from the original input that were present in the visited hash set.
///
/// # Arguments
///
/// * `prune_length` - The length of the pruned record batch.
/// * `deleted_offset` - The offset to the indices.
/// * `visited_rows` - The hash set of visited indices.
///
/// # Returns
///
/// A [PrimitiveArray] of the specified type T, containing the semi indices.
pub fn get_pruning_semi_indices<T: ArrowPrimitiveType>(
    prune_length: usize,
    deleted_offset: usize,
    visited_rows: &HashSet<usize>,
) -> PrimitiveArray<T>
where
    NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
{
    let mut bitmap = BooleanBufferBuilder::new(prune_length);
    bitmap.append_n(prune_length, false);
    // Mark the indices as true if they are present in the visited hash set:
    (0..prune_length).for_each(|v| {
        let row = &(v + deleted_offset);
        bitmap.set_bit(v, visited_rows.contains(row));
    });
    // Get the semi index:
    (0..prune_length)
        .filter_map(|idx| (bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx)))
        .collect()
}
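
// A small illustrative test (an addition, not upstream coverage): the semi and
// anti indices partition the pruned rows according to the visited set.
#[cfg(test)]
mod pruning_indices_sketch {
    use super::*;
    use arrow_array::types::UInt64Type;

    #[test]
    fn semi_and_anti_indices_partition_pruned_rows() {
        // Rows 0..4 are pruned at offset 0; rows 1 and 3 were matched:
        let visited: HashSet<usize> = [1usize, 3].into_iter().collect();
        let semi = get_pruning_semi_indices::<UInt64Type>(4, 0, &visited);
        let anti = get_pruning_anti_indices::<UInt64Type>(4, 0, &visited);
        assert_eq!(semi.values().to_vec(), vec![1, 3]);
        assert_eq!(anti.values().to_vec(), vec![0, 2]);
    }
}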

pub fn combine_two_batches(
    output_schema: &SchemaRef,
    left_batch: Option<RecordBatch>,
    right_batch: Option<RecordBatch>,
) -> Result<Option<RecordBatch>> {
    match (left_batch, right_batch) {
        (Some(batch), None) | (None, Some(batch)) => {
            // If only one of the batches is present, return it:
            Ok(Some(batch))
        }
        (Some(left_batch), Some(right_batch)) => {
            // If both batches are present, concatenate them:
            concat_batches(output_schema, &[left_batch, right_batch])
                .map_err(|e| arrow_datafusion_err!(e))
                .map(Some)
        }
        (None, None) => {
            // If neither is present, return `None`:
            Ok(None)
        }
    }
}
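
// A minimal sketch (added for illustration) of `combine_two_batches` semantics:
// one side present passes through, both sides concatenate, neither yields `None`.
// The single-column schema here is an assumption for the example only.
#[cfg(test)]
mod combine_batches_sketch {
    use super::*;
    use arrow_array::Int32Array;
    use arrow_schema::{DataType, Field};

    #[test]
    fn combine_handles_all_presence_combinations() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(vec![1, 2])) as _],
        )?;
        // Only one side present: the batch passes through unchanged.
        let one = combine_two_batches(&schema, Some(batch.clone()), None)?.unwrap();
        assert_eq!(one.num_rows(), 2);
        // Both sides present: the batches are concatenated.
        let both = combine_two_batches(&schema, Some(batch.clone()), Some(batch))?.unwrap();
        assert_eq!(both.num_rows(), 4);
        // Neither present: no output batch.
        assert!(combine_two_batches(&schema, None, None)?.is_none());
        Ok(())
    }
}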

/// Records the visited indices from the input `PrimitiveArray` of type `T` into the given hash set `visited`.
/// This function will insert the indices (offset by `offset`) into the `visited` hash set.
///
/// # Arguments
///
/// * `visited` - A hash set to store the visited indices.
/// * `offset` - An offset to the indices in the `PrimitiveArray`.
/// * `indices` - The input `PrimitiveArray` of type `T` which stores the indices to be recorded.
pub fn record_visited_indices<T: ArrowPrimitiveType>(
    visited: &mut HashSet<usize>,
    offset: usize,
    indices: &PrimitiveArray<T>,
) {
    for i in indices.values() {
        visited.insert(i.as_usize() + offset);
    }
}
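
// A short illustrative test (an addition) showing the offset arithmetic of
// `record_visited_indices`: input indices 0 and 2 with offset 10 land as rows
// 10 and 12 in the visited set.
#[cfg(test)]
mod record_visited_sketch {
    use super::*;
    use arrow_array::types::UInt32Type;

    #[test]
    fn indices_are_recorded_with_offset() {
        let mut visited = HashSet::new();
        let indices = PrimitiveArray::<UInt32Type>::from(vec![0u32, 2]);
        record_visited_indices(&mut visited, 10, &indices);
        assert!(visited.contains(&10) && visited.contains(&12));
    }
}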

#[derive(Debug)]
pub struct StreamJoinSideMetrics {
    /// Number of batches consumed by this operator
    pub(crate) input_batches: metrics::Count,
    /// Number of rows consumed by this operator
    pub(crate) input_rows: metrics::Count,
}

/// Metrics for stream join operators such as `SymmetricHashJoinExec`
#[derive(Debug)]
pub struct StreamJoinMetrics {
    /// Number of left batches/rows consumed by this operator
    pub(crate) left: StreamJoinSideMetrics,
    /// Number of right batches/rows consumed by this operator
    pub(crate) right: StreamJoinSideMetrics,
    /// Memory used by sides in bytes
    pub(crate) stream_memory_usage: metrics::Gauge,
    /// Number of batches produced by this operator
    pub(crate) output_batches: metrics::Count,
    /// Number of rows produced by this operator
    pub(crate) output_rows: metrics::Count,
}

impl StreamJoinMetrics {
    pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self {
        let input_batches =
            MetricBuilder::new(metrics).counter("input_batches", partition);
        let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition);
        let left = StreamJoinSideMetrics {
            input_batches,
            input_rows,
        };

        let input_batches =
            MetricBuilder::new(metrics).counter("input_batches", partition);
        let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition);
        let right = StreamJoinSideMetrics {
            input_batches,
            input_rows,
        };

        let stream_memory_usage =
            MetricBuilder::new(metrics).gauge("stream_memory_usage", partition);

        let output_batches =
            MetricBuilder::new(metrics).counter("output_batches", partition);

        let output_rows = MetricBuilder::new(metrics).output_rows(partition);

        Self {
            left,
            right,
            output_batches,
            stream_memory_usage,
            output_rows,
        }
    }
}

/// Updates sorted filter expressions with corresponding node indices from the
/// expression interval graph.
///
/// This function iterates through the provided sorted filter expressions,
/// gathers the corresponding node indices from the expression interval graph,
/// and then updates the sorted expressions with these indices. It ensures
/// that these sorted expressions are aligned with the structure of the graph.
fn update_sorted_exprs_with_node_indices(
    graph: &mut ExprIntervalGraph,
    sorted_exprs: &mut [SortedFilterExpr],
) {
    // Extract filter expressions from the sorted expressions:
    let filter_exprs = sorted_exprs
        .iter()
        .map(|expr| Arc::clone(expr.filter_expr()))
        .collect::<Vec<_>>();

    // Gather corresponding node indices for the extracted filter expressions from the graph:
    let child_node_indices = graph.gather_node_indices(&filter_exprs);

    // Iterate through the sorted expressions and the gathered node indices:
    for (sorted_expr, (_, index)) in sorted_exprs.iter_mut().zip(child_node_indices) {
        // Update each sorted expression with the corresponding node index:
        sorted_expr.set_node_index(index);
    }
}

/// Prepares and sorts expressions based on a given filter, left and right execution plans, and sort expressions.
///
/// # Arguments
///
/// * `filter` - The join filter to base the sorting on.
/// * `left` - The left execution plan.
/// * `right` - The right execution plan.
/// * `left_sort_exprs` - The expressions to sort on the left side.
/// * `right_sort_exprs` - The expressions to sort on the right side.
///
/// # Returns
///
/// * A tuple consisting of the sorted filter expressions for the left and right sides, and an expression interval graph.
pub fn prepare_sorted_exprs(
    filter: &JoinFilter,
    left: &Arc<dyn ExecutionPlan>,
    right: &Arc<dyn ExecutionPlan>,
    left_sort_exprs: &[PhysicalSortExpr],
    right_sort_exprs: &[PhysicalSortExpr],
) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> {
    let err = || plan_datafusion_err!("Filter does not include the child order");

    // Build the filter order for the left side:
    let left_temp_sorted_filter_expr = build_filter_input_order(
        JoinSide::Left,
        filter,
        &left.schema(),
        &left_sort_exprs[0],
    )?
    .ok_or_else(err)?;

    // Build the filter order for the right side:
    let right_temp_sorted_filter_expr = build_filter_input_order(
        JoinSide::Right,
        filter,
        &right.schema(),
        &right_sort_exprs[0],
    )?
    .ok_or_else(err)?;

    // Collect the sorted expressions:
    let mut sorted_exprs =
        vec![left_temp_sorted_filter_expr, right_temp_sorted_filter_expr];

    // Build the expression interval graph:
    let mut graph =
        ExprIntervalGraph::try_new(Arc::clone(filter.expression()), filter.schema())?;

    // Update sorted expressions with node indices:
    update_sorted_exprs_with_node_indices(&mut graph, &mut sorted_exprs);

    // Swap and remove to get the final sorted filter expressions:
    let right_sorted_filter_expr = sorted_exprs.swap_remove(1);
    let left_sorted_filter_expr = sorted_exprs.swap_remove(0);

    Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph))
}

#[cfg(test)]
pub mod tests {

    use super::*;
    use crate::{joins::test_utils::complicated_filter, joins::utils::ColumnIndex};

    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field};
    use datafusion_expr::Operator;
    use datafusion_physical_expr::expressions::{binary, cast, col};

    #[test]
    fn test_column_exchange() -> Result<()> {
        let left_child_schema =
            Schema::new(vec![Field::new("left_1", DataType::Int32, true)]);
        // Sorting information for the left side:
        let left_child_sort_expr = PhysicalSortExpr {
            expr: col("left_1", &left_child_schema)?,
            options: SortOptions::default(),
        };

        let right_child_schema = Schema::new(vec![
            Field::new("right_1", DataType::Int32, true),
            Field::new("right_2", DataType::Int32, true),
        ]);
        // Sorting information for the right side:
        let right_child_sort_expr = PhysicalSortExpr {
            expr: binary(
                col("right_1", &right_child_schema)?,
                Operator::Plus,
                col("right_2", &right_child_schema)?,
                &right_child_schema,
            )?,
            options: SortOptions::default(),
        };

        let intermediate_schema = Schema::new(vec![
            Field::new("filter_1", DataType::Int32, true),
            Field::new("filter_2", DataType::Int32, true),
            Field::new("filter_3", DataType::Int32, true),
        ]);
        // Our filter expression is: left_1 > right_1 + right_2.
        let filter_left = col("filter_1", &intermediate_schema)?;
        let filter_right = binary(
            col("filter_2", &intermediate_schema)?,
            Operator::Plus,
            col("filter_3", &intermediate_schema)?,
            &intermediate_schema,
        )?;
        let filter_expr = binary(
            Arc::clone(&filter_left),
            Operator::Gt,
            Arc::clone(&filter_right),
            &intermediate_schema,
        )?;
        let column_indices = vec![
            ColumnIndex {
                index: 0,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: 0,
                side: JoinSide::Right,
            },
            ColumnIndex {
                index: 1,
                side: JoinSide::Right,
            },
        ];
        let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);

        let left_sort_filter_expr = build_filter_input_order(
            JoinSide::Left,
            &filter,
            &Arc::new(left_child_schema),
            &left_child_sort_expr,
        )?
        .unwrap();
        assert!(left_child_sort_expr.eq(left_sort_filter_expr.origin_sorted_expr()));

        let right_sort_filter_expr = build_filter_input_order(
            JoinSide::Right,
            &filter,
            &Arc::new(right_child_schema),
            &right_child_sort_expr,
        )?
        .unwrap();
        assert!(right_child_sort_expr.eq(right_sort_filter_expr.origin_sorted_expr()));

        // Assert that the adjusted (left) filter expression matches `filter_left`:
        assert!(filter_left.eq(left_sort_filter_expr.filter_expr()));
        // Assert that the adjusted (right) filter expression matches `filter_right`:
        assert!(filter_right.eq(right_sort_filter_expr.filter_expr()));
        Ok(())
    }

    #[test]
    fn test_column_collector() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("0", DataType::Int32, true),
            Field::new("1", DataType::Int32, true),
            Field::new("2", DataType::Int32, true),
        ]);
        let filter_expr = complicated_filter(&schema)?;
        let columns = collect_columns(&filter_expr);
        assert_eq!(columns.len(), 3);
        Ok(())
    }

    #[test]
    fn find_expr_inside_expr() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("0", DataType::Int32, true),
            Field::new("1", DataType::Int32, true),
            Field::new("2", DataType::Int32, true),
        ]);
        let filter_expr = complicated_filter(&schema)?;

        let expr_1 = Arc::new(Column::new("gnz", 0)) as _;
        assert!(!check_filter_expr_contains_sort_information(
            &filter_expr,
            &expr_1
        ));

        let expr_2 = col("1", &schema)? as _;

        assert!(check_filter_expr_contains_sort_information(
            &filter_expr,
            &expr_2
        ));

        let expr_3 = cast(
            binary(
                col("0", &schema)?,
                Operator::Plus,
                col("1", &schema)?,
                &schema,
            )?,
            &schema,
            DataType::Int64,
        )?;

        assert!(check_filter_expr_contains_sort_information(
            &filter_expr,
            &expr_3
        ));

        let expr_4 = Arc::new(Column::new("1", 42)) as _;

        assert!(!check_filter_expr_contains_sort_information(
            &filter_expr,
            &expr_4,
        ));
        Ok(())
    }

    #[test]
    fn build_sorted_expr() -> Result<()> {
        let left_schema = Schema::new(vec![
            Field::new("la1", DataType::Int32, false),
            Field::new("lb1", DataType::Int32, false),
            Field::new("lc1", DataType::Int32, false),
            Field::new("lt1", DataType::Int32, false),
            Field::new("la2", DataType::Int32, false),
            Field::new("la1_des", DataType::Int32, false),
        ]);

        let right_schema = Schema::new(vec![
            Field::new("ra1", DataType::Int32, false),
            Field::new("rb1", DataType::Int32, false),
            Field::new("rc1", DataType::Int32, false),
            Field::new("rt1", DataType::Int32, false),
            Field::new("ra2", DataType::Int32, false),
            Field::new("ra1_des", DataType::Int32, false),
        ]);

        let intermediate_schema = Schema::new(vec![
            Field::new("0", DataType::Int32, true),
            Field::new("1", DataType::Int32, true),
            Field::new("2", DataType::Int32, true),
        ]);
        let filter_expr = complicated_filter(&intermediate_schema)?;
        let column_indices = vec![
            ColumnIndex {
                index: 0,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: 4,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: 0,
                side: JoinSide::Right,
            },
        ];
        let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);

        let left_schema = Arc::new(left_schema);
        let right_schema = Arc::new(right_schema);

        assert!(build_filter_input_order(
            JoinSide::Left,
            &filter,
            &left_schema,
            &PhysicalSortExpr {
                expr: col("la1", left_schema.as_ref())?,
                options: SortOptions::default(),
            }
        )?
        .is_some());
        assert!(build_filter_input_order(
            JoinSide::Left,
            &filter,
            &left_schema,
            &PhysicalSortExpr {
                expr: col("lt1", left_schema.as_ref())?,
                options: SortOptions::default(),
            }
        )?
        .is_none());
        assert!(build_filter_input_order(
            JoinSide::Right,
            &filter,
            &right_schema,
            &PhysicalSortExpr {
                expr: col("ra1", right_schema.as_ref())?,
                options: SortOptions::default(),
            }
        )?
        .is_some());
        assert!(build_filter_input_order(
            JoinSide::Right,
            &filter,
            &right_schema,
            &PhysicalSortExpr {
                expr: col("rb1", right_schema.as_ref())?,
                options: SortOptions::default(),
            }
        )?
        .is_none());

        Ok(())
    }

    // Test the case where we have an "ORDER BY a + b" and the join filter
    // condition includes "a - b".
    #[test]
    fn sorted_filter_expr_build() -> Result<()> {
        let intermediate_schema = Schema::new(vec![
            Field::new("0", DataType::Int32, true),
            Field::new("1", DataType::Int32, true),
        ]);
        let filter_expr = binary(
            col("0", &intermediate_schema)?,
            Operator::Minus,
            col("1", &intermediate_schema)?,
            &intermediate_schema,
        )?;
        let column_indices = vec![
            ColumnIndex {
                index: 0,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: 1,
                side: JoinSide::Left,
            },
        ];
        let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);

        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]);

        let sorted = PhysicalSortExpr {
            expr: binary(
                col("a", &schema)?,
                Operator::Plus,
                col("b", &schema)?,
                &schema,
            )?,
            options: SortOptions::default(),
        };

        let res = convert_sort_expr_with_filter_schema(
            &JoinSide::Left,
            &filter,
            &Arc::new(schema),
            &sorted,
        )?;
        assert!(res.is_none());
        Ok(())
    }

    #[test]
    fn test_shrink_if_necessary() {
        let scale_factor = 4;
        let mut join_hash_map = PruningJoinHashMap::with_capacity(100);
        let data_size = 2000;
        let deleted_part = 3 * data_size / 4;
        // Add elements to the JoinHashMap
        for hash_value in 0..data_size {
            join_hash_map.map.insert(
                hash_value,
                (hash_value, hash_value),
                |(hash, _)| *hash,
            );
        }

        assert_eq!(join_hash_map.map.len(), data_size as usize);
        assert!(join_hash_map.map.capacity() >= data_size as usize);

        // Remove some elements from the JoinHashMap
        for hash_value in 0..deleted_part {
            join_hash_map
                .map
                .remove_entry(hash_value, |(hash, _)| hash_value == *hash);
        }

        assert_eq!(join_hash_map.map.len(), (data_size - deleted_part) as usize);

        // Old capacity
        let old_capacity = join_hash_map.map.capacity();

        // Test shrink_if_necessary
        join_hash_map.shrink_if_necessary(scale_factor);

        // The capacity should be reduced by the scale factor
        let new_expected_capacity =
            join_hash_map.map.capacity() * (scale_factor - 1) / scale_factor;
        assert!(join_hash_map.map.capacity() >= new_expected_capacity);
        assert!(join_hash_map.map.capacity() <= old_capacity);
    }
}