/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/array_agg.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`] |
19 | | |
20 | | use arrow::array::{new_empty_array, Array, ArrayRef, AsArray, StructArray}; |
21 | | use arrow::datatypes::DataType; |
22 | | |
23 | | use arrow_schema::{Field, Fields}; |
24 | | use datafusion_common::cast::as_list_array; |
25 | | use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx}; |
26 | | use datafusion_common::{exec_err, ScalarValue}; |
27 | | use datafusion_common::{internal_err, Result}; |
28 | | use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; |
29 | | use datafusion_expr::utils::format_state_name; |
30 | | use datafusion_expr::AggregateUDFImpl; |
31 | | use datafusion_expr::{Accumulator, Signature, Volatility}; |
32 | | use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays; |
33 | | use datafusion_functions_aggregate_common::utils::ordering_fields; |
34 | | use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; |
35 | | use std::collections::{HashSet, VecDeque}; |
36 | | use std::sync::Arc; |
37 | | |
// Generates the `array_agg` expression function and the `array_agg_udaf`
// accessor for the [`ArrayAgg`] UDAF (single-argument aggregate).
make_udaf_expr_and_func!(
    ArrayAgg,
    array_agg,
    expression,
    "input values, including nulls, concatenated into an array",
    array_agg_udaf
);
45 | | |
#[derive(Debug)]
/// ARRAY_AGG aggregate expression
pub struct ArrayAgg {
    // Function signature: one argument of any type (see the `Default` impl).
    signature: Signature,
}
51 | | |
52 | | impl Default for ArrayAgg { |
53 | 1 | fn default() -> Self { |
54 | 1 | Self { |
55 | 1 | signature: Signature::any(1, Volatility::Immutable), |
56 | 1 | } |
57 | 1 | } |
58 | | } |
59 | | |
60 | | impl AggregateUDFImpl for ArrayAgg { |
61 | 0 | fn as_any(&self) -> &dyn std::any::Any { |
62 | 0 | self |
63 | 0 | } |
64 | | |
65 | 19 | fn name(&self) -> &str { |
66 | 19 | "array_agg" |
67 | 19 | } |
68 | | |
69 | 0 | fn aliases(&self) -> &[String] { |
70 | 0 | &[] |
71 | 0 | } |
72 | | |
73 | 7 | fn signature(&self) -> &Signature { |
74 | 7 | &self.signature |
75 | 7 | } |
76 | | |
77 | 7 | fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { |
78 | 7 | Ok(DataType::List(Arc::new(Field::new( |
79 | 7 | "item", |
80 | 7 | arg_types[0].clone(), |
81 | 7 | true, |
82 | 7 | )))) |
83 | 7 | } |
84 | | |
85 | 0 | fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { |
86 | 0 | if args.is_distinct { |
87 | 0 | return Ok(vec![Field::new_list( |
88 | 0 | format_state_name(args.name, "distinct_array_agg"), |
89 | 0 | // See COMMENTS.md to understand why nullable is set to true |
90 | 0 | Field::new("item", args.input_types[0].clone(), true), |
91 | 0 | true, |
92 | 0 | )]); |
93 | 0 | } |
94 | 0 |
|
95 | 0 | let mut fields = vec![Field::new_list( |
96 | 0 | format_state_name(args.name, "array_agg"), |
97 | 0 | // See COMMENTS.md to understand why nullable is set to true |
98 | 0 | Field::new("item", args.input_types[0].clone(), true), |
99 | 0 | true, |
100 | 0 | )]; |
101 | 0 |
|
102 | 0 | if args.ordering_fields.is_empty() { |
103 | 0 | return Ok(fields); |
104 | 0 | } |
105 | 0 |
|
106 | 0 | let orderings = args.ordering_fields.to_vec(); |
107 | 0 | fields.push(Field::new_list( |
108 | 0 | format_state_name(args.name, "array_agg_orderings"), |
109 | 0 | Field::new("item", DataType::Struct(Fields::from(orderings)), true), |
110 | 0 | false, |
111 | 0 | )); |
112 | 0 |
|
113 | 0 | Ok(fields) |
114 | 0 | } |
115 | | |
116 | 0 | fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
117 | 0 | let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; |
118 | | |
119 | 0 | if acc_args.is_distinct { |
120 | 0 | return Ok(Box::new(DistinctArrayAggAccumulator::try_new(&data_type)?)); |
121 | 0 | } |
122 | 0 |
|
123 | 0 | if acc_args.ordering_req.is_empty() { |
124 | 0 | return Ok(Box::new(ArrayAggAccumulator::try_new(&data_type)?)); |
125 | 0 | } |
126 | | |
127 | 0 | let ordering_dtypes = acc_args |
128 | 0 | .ordering_req |
129 | 0 | .iter() |
130 | 0 | .map(|e| e.expr.data_type(acc_args.schema)) |
131 | 0 | .collect::<Result<Vec<_>>>()?; |
132 | | |
133 | 0 | OrderSensitiveArrayAggAccumulator::try_new( |
134 | 0 | &data_type, |
135 | 0 | &ordering_dtypes, |
136 | 0 | acc_args.ordering_req.to_vec(), |
137 | 0 | acc_args.is_reversed, |
138 | 0 | ) |
139 | 0 | .map(|acc| Box::new(acc) as _) |
140 | 0 | } |
141 | | |
142 | 3 | fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { |
143 | 3 | datafusion_expr::ReversedUDAF::Reversed(array_agg_udaf()) |
144 | 3 | } |
145 | | } |
146 | | |
#[derive(Debug)]
pub struct ArrayAggAccumulator {
    // Input batches buffered so far; concatenated once at evaluation time.
    values: Vec<ArrayRef>,
    // Element data type of the values being aggregated.
    datatype: DataType,
}
152 | | |
153 | | impl ArrayAggAccumulator { |
154 | | /// new array_agg accumulator based on given item data type |
155 | 0 | pub fn try_new(datatype: &DataType) -> Result<Self> { |
156 | 0 | Ok(Self { |
157 | 0 | values: vec![], |
158 | 0 | datatype: datatype.clone(), |
159 | 0 | }) |
160 | 0 | } |
161 | | } |
162 | | |
163 | | impl Accumulator for ArrayAggAccumulator { |
164 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
165 | 0 | // Append value like Int64Array(1,2,3) |
166 | 0 | if values.is_empty() { |
167 | 0 | return Ok(()); |
168 | 0 | } |
169 | 0 |
|
170 | 0 | if values.len() != 1 { |
171 | 0 | return internal_err!("expects single batch"); |
172 | 0 | } |
173 | 0 |
|
174 | 0 | let val = Arc::clone(&values[0]); |
175 | 0 | if val.len() > 0 { |
176 | 0 | self.values.push(val); |
177 | 0 | } |
178 | 0 | Ok(()) |
179 | 0 | } |
180 | | |
181 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
182 | 0 | // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6)) |
183 | 0 | if states.is_empty() { |
184 | 0 | return Ok(()); |
185 | 0 | } |
186 | 0 |
|
187 | 0 | if states.len() != 1 { |
188 | 0 | return internal_err!("expects single state"); |
189 | 0 | } |
190 | | |
191 | 0 | let list_arr = as_list_array(&states[0])?; |
192 | 0 | for arr in list_arr.iter().flatten() { |
193 | 0 | self.values.push(arr); |
194 | 0 | } |
195 | 0 | Ok(()) |
196 | 0 | } |
197 | | |
198 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
199 | 0 | Ok(vec![self.evaluate()?]) |
200 | 0 | } |
201 | | |
202 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
203 | 0 | // Transform Vec<ListArr> to ListArr |
204 | 0 | let element_arrays: Vec<&dyn Array> = |
205 | 0 | self.values.iter().map(|a| a.as_ref()).collect(); |
206 | 0 |
|
207 | 0 | if element_arrays.is_empty() { |
208 | 0 | return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); |
209 | 0 | } |
210 | | |
211 | 0 | let concated_array = arrow::compute::concat(&element_arrays)?; |
212 | 0 | let list_array = array_into_list_array_nullable(concated_array); |
213 | 0 |
|
214 | 0 | Ok(ScalarValue::List(Arc::new(list_array))) |
215 | 0 | } |
216 | | |
217 | 0 | fn size(&self) -> usize { |
218 | 0 | std::mem::size_of_val(self) |
219 | 0 | + (std::mem::size_of::<ArrayRef>() * self.values.capacity()) |
220 | 0 | + self |
221 | 0 | .values |
222 | 0 | .iter() |
223 | 0 | .map(|arr| arr.get_array_memory_size()) |
224 | 0 | .sum::<usize>() |
225 | 0 | + self.datatype.size() |
226 | 0 | - std::mem::size_of_val(&self.datatype) |
227 | 0 | } |
228 | | } |
229 | | |
#[derive(Debug)]
struct DistinctArrayAggAccumulator {
    // Distinct values seen so far, de-duplicated via the hash set.
    values: HashSet<ScalarValue>,
    // Element data type of the values being aggregated.
    datatype: DataType,
}
235 | | |
236 | | impl DistinctArrayAggAccumulator { |
237 | 0 | pub fn try_new(datatype: &DataType) -> Result<Self> { |
238 | 0 | Ok(Self { |
239 | 0 | values: HashSet::new(), |
240 | 0 | datatype: datatype.clone(), |
241 | 0 | }) |
242 | 0 | } |
243 | | } |
244 | | |
245 | | impl Accumulator for DistinctArrayAggAccumulator { |
246 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
247 | 0 | Ok(vec![self.evaluate()?]) |
248 | 0 | } |
249 | | |
250 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
251 | 0 | if values.len() != 1 { |
252 | 0 | return internal_err!("expects single batch"); |
253 | 0 | } |
254 | 0 |
|
255 | 0 | let array = &values[0]; |
256 | | |
257 | 0 | for i in 0..array.len() { |
258 | 0 | let scalar = ScalarValue::try_from_array(&array, i)?; |
259 | 0 | self.values.insert(scalar); |
260 | | } |
261 | | |
262 | 0 | Ok(()) |
263 | 0 | } |
264 | | |
265 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
266 | 0 | if states.is_empty() { |
267 | 0 | return Ok(()); |
268 | 0 | } |
269 | 0 |
|
270 | 0 | if states.len() != 1 { |
271 | 0 | return internal_err!("expects single state"); |
272 | 0 | } |
273 | 0 |
|
274 | 0 | states[0] |
275 | 0 | .as_list::<i32>() |
276 | 0 | .iter() |
277 | 0 | .flatten() |
278 | 0 | .try_for_each(|val| self.update_batch(&[val])) |
279 | 0 | } |
280 | | |
281 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
282 | 0 | let values: Vec<ScalarValue> = self.values.iter().cloned().collect(); |
283 | 0 | if values.is_empty() { |
284 | 0 | return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); |
285 | 0 | } |
286 | 0 | let arr = ScalarValue::new_list(&values, &self.datatype, true); |
287 | 0 | Ok(ScalarValue::List(arr)) |
288 | 0 | } |
289 | | |
290 | 0 | fn size(&self) -> usize { |
291 | 0 | std::mem::size_of_val(self) + ScalarValue::size_of_hashset(&self.values) |
292 | 0 | - std::mem::size_of_val(&self.values) |
293 | 0 | + self.datatype.size() |
294 | 0 | - std::mem::size_of_val(&self.datatype) |
295 | 0 | } |
296 | | } |
297 | | |
/// Accumulator for a `ARRAY_AGG(... ORDER BY ..., ...)` aggregation. In a multi
/// partition setting, partial aggregations are computed for every partition,
/// and then their results are merged.
#[derive(Debug)]
pub(crate) struct OrderSensitiveArrayAggAccumulator {
    /// Stores entries in the `ARRAY_AGG` result.
    values: Vec<ScalarValue>,
    /// Stores values of ordering requirement expressions corresponding to each
    /// entry in `values` (one inner `Vec` per row, one element per ORDER BY
    /// expression). This information is used when merging results from
    /// different partitions. For detailed information how merging is done, see
    /// [`merge_ordered_arrays`].
    ordering_values: Vec<Vec<ScalarValue>>,
    /// Stores datatypes of expressions inside values and ordering requirement
    /// expressions. The first element is the value type; the rest mirror the
    /// ORDER BY expressions in order.
    datatypes: Vec<DataType>,
    /// Stores the ordering requirement of the `Accumulator`.
    ordering_req: LexOrdering,
    /// Whether the aggregation is running in reverse.
    reverse: bool,
}
318 | | |
319 | | impl OrderSensitiveArrayAggAccumulator { |
320 | | /// Create a new order-sensitive ARRAY_AGG accumulator based on the given |
321 | | /// item data type. |
322 | 0 | pub fn try_new( |
323 | 0 | datatype: &DataType, |
324 | 0 | ordering_dtypes: &[DataType], |
325 | 0 | ordering_req: LexOrdering, |
326 | 0 | reverse: bool, |
327 | 0 | ) -> Result<Self> { |
328 | 0 | let mut datatypes = vec![datatype.clone()]; |
329 | 0 | datatypes.extend(ordering_dtypes.iter().cloned()); |
330 | 0 | Ok(Self { |
331 | 0 | values: vec![], |
332 | 0 | ordering_values: vec![], |
333 | 0 | datatypes, |
334 | 0 | ordering_req, |
335 | 0 | reverse, |
336 | 0 | }) |
337 | 0 | } |
338 | | } |
339 | | |
impl Accumulator for OrderSensitiveArrayAggAccumulator {
    /// Buffers one batch row-by-row: column 0 is the aggregated value, the
    /// remaining columns are the ORDER BY expression values for that row.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        let n_row = values[0].len();
        for index in 0..n_row {
            let row = get_row_at_idx(values, index)?;
            self.values.push(row[0].clone());
            self.ordering_values.push(row[1..].to_vec());
        }

        Ok(())
    }

    /// Merges partial states from other partitions while preserving the
    /// required ordering.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }

        // First entry in the state is the aggregation result. Second entry
        // stores values received for ordering requirement columns for each
        // aggregation value inside `ARRAY_AGG` list. For each `StructArray`
        // inside `ARRAY_AGG` list, we will receive an `Array` that stores values
        // received from its ordering requirement expression. (This information
        // is necessary for during merging).
        // NOTE: the pattern accepts >= 2 elements; extras are ignored.
        let [array_agg_values, agg_orderings, ..] = &states else {
            return exec_err!("State should have two elements");
        };
        let Some(agg_orderings) = agg_orderings.as_list_opt::<i32>() else {
            return exec_err!("Expects to receive a list array");
        };

        // Stores ARRAY_AGG results coming from each partition
        let mut partition_values = vec![];
        // Stores ordering requirement expression results coming from each partition
        let mut partition_ordering_values = vec![];

        // Existing values should be merged also.
        partition_values.push(self.values.clone().into());
        partition_ordering_values.push(self.ordering_values.clone().into());

        // Convert array to Scalars to sort them easily. Convert back to array at evaluation.
        let array_agg_res = ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
        for v in array_agg_res.into_iter() {
            partition_values.push(v.into());
        }

        let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;

        for partition_ordering_rows in orderings.into_iter() {
            // Extract value from struct to ordering_rows for each group/partition
            let ordering_value = partition_ordering_rows.into_iter().map(|ordering_row| {
                if let ScalarValue::Struct(s) = ordering_row {
                    let mut ordering_columns_per_row = vec![];

                    for column in s.columns() {
                        // Each struct holds exactly one row per entry, hence index 0.
                        let sv = ScalarValue::try_from_array(column, 0)?;
                        ordering_columns_per_row.push(sv);
                    }

                    Ok(ordering_columns_per_row)
                } else {
                    exec_err!(
                        "Expects to receive ScalarValue::Struct(Arc<StructArray>) but got:{:?}",
                        ordering_row.data_type()
                    )
                }
            }).collect::<Result<VecDeque<_>>>()?;

            partition_ordering_values.push(ordering_value);
        }

        // Sort options of each ORDER BY expression drive the k-way merge.
        let sort_options = self
            .ordering_req
            .iter()
            .map(|sort_expr| sort_expr.options)
            .collect::<Vec<_>>();

        // Replace buffered state with the merged result of all partitions
        // (including our own, pushed above).
        (self.values, self.ordering_values) = merge_ordered_arrays(
            &mut partition_values,
            &mut partition_ordering_values,
            &sort_options,
        )?;

        Ok(())
    }

    /// State = [aggregated values, ordering values], matching `state_fields`.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let mut result = vec![self.evaluate()?];
        result.push(self.evaluate_orderings()?);

        Ok(result)
    }

    /// Emits the buffered values as a single-element list scalar; reversed
    /// order when running as a reversed aggregate.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        if self.values.is_empty() {
            return Ok(ScalarValue::new_null_list(
                self.datatypes[0].clone(),
                true,
                1,
            ));
        }

        let values = self.values.clone();
        let array = if self.reverse {
            ScalarValue::new_list_from_iter(
                values.into_iter().rev(),
                &self.datatypes[0],
                true,
            )
        } else {
            ScalarValue::new_list_from_iter(values.into_iter(), &self.datatypes[0], true)
        };
        Ok(ScalarValue::List(array))
    }

    fn size(&self) -> usize {
        let mut total = std::mem::size_of_val(self)
            + ScalarValue::size_of_vec(&self.values)
            - std::mem::size_of_val(&self.values);

        // Add size of the `self.ordering_values`
        total +=
            std::mem::size_of::<Vec<ScalarValue>>() * self.ordering_values.capacity();
        for row in &self.ordering_values {
            total += ScalarValue::size_of_vec(row) - std::mem::size_of_val(row);
        }

        // Add size of the `self.datatypes`
        total += std::mem::size_of::<DataType>() * self.datatypes.capacity();
        for dtype in &self.datatypes {
            total += dtype.size() - std::mem::size_of_val(dtype);
        }

        // Add size of the `self.ordering_req`
        total += std::mem::size_of::<PhysicalSortExpr>() * self.ordering_req.capacity();
        // TODO: Calculate size of each `PhysicalSortExpr` more accurately.
        total
    }
}
482 | | |
483 | | impl OrderSensitiveArrayAggAccumulator { |
484 | 0 | fn evaluate_orderings(&self) -> Result<ScalarValue> { |
485 | 0 | let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]); |
486 | 0 | let num_columns = fields.len(); |
487 | 0 | let struct_field = Fields::from(fields.clone()); |
488 | 0 |
|
489 | 0 | let mut column_wise_ordering_values = vec![]; |
490 | 0 | for i in 0..num_columns { |
491 | 0 | let column_values = self |
492 | 0 | .ordering_values |
493 | 0 | .iter() |
494 | 0 | .map(|x| x[i].clone()) |
495 | 0 | .collect::<Vec<_>>(); |
496 | 0 | let array = if column_values.is_empty() { |
497 | 0 | new_empty_array(fields[i].data_type()) |
498 | | } else { |
499 | 0 | ScalarValue::iter_to_array(column_values.into_iter())? |
500 | | }; |
501 | 0 | column_wise_ordering_values.push(array); |
502 | | } |
503 | | |
504 | 0 | let ordering_array = |
505 | 0 | StructArray::try_new(struct_field, column_wise_ordering_values, None)?; |
506 | 0 | Ok(ScalarValue::List(Arc::new(array_into_list_array_nullable( |
507 | 0 | Arc::new(ordering_array), |
508 | 0 | )))) |
509 | 0 | } |
510 | | } |
511 | | |
#[cfg(test)]
mod tests {
    use super::*;

    use std::collections::VecDeque;
    use std::sync::Arc;

    use arrow::array::Int64Array;
    use arrow_schema::SortOptions;

    use datafusion_common::utils::get_row_at_idx;
    use datafusion_common::{Result, ScalarValue};

    // Merging two identically-ordered ascending partitions should interleave
    // rows so the combined ordering columns remain ascending.
    #[test]
    fn test_merge_asc() -> Result<()> {
        // Ordering columns for the left partition (two ORDER BY expressions).
        let lhs_arrays: Vec<ArrayRef> = vec![
            Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2])),
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])),
        ];
        let n_row = lhs_arrays[0].len();
        let lhs_orderings = (0..n_row)
            .map(|idx| get_row_at_idx(&lhs_arrays, idx))
            .collect::<Result<VecDeque<_>>>()?;

        // Right partition uses the same ordering values, so every row ties
        // with its counterpart and the merge duplicates each position.
        let rhs_arrays: Vec<ArrayRef> = vec![
            Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2])),
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])),
        ];
        let n_row = rhs_arrays[0].len();
        let rhs_orderings = (0..n_row)
            .map(|idx| get_row_at_idx(&rhs_arrays, idx))
            .collect::<Result<VecDeque<_>>>()?;
        let sort_options = vec![
            SortOptions {
                descending: false,
                nulls_first: false,
            },
            SortOptions {
                descending: false,
                nulls_first: false,
            },
        ];

        let lhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])) as ArrayRef;
        let lhs_vals = (0..lhs_vals_arr.len())
            .map(|idx| ScalarValue::try_from_array(&lhs_vals_arr, idx))
            .collect::<Result<VecDeque<_>>>()?;

        let rhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])) as ArrayRef;
        let rhs_vals = (0..rhs_vals_arr.len())
            .map(|idx| ScalarValue::try_from_array(&rhs_vals_arr, idx))
            .collect::<Result<VecDeque<_>>>()?;
        // Expected: each value appears twice (once per partition), in order.
        let expected =
            Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2, 2, 3, 3, 4, 4])) as ArrayRef;
        let expected_ts = vec![
            Arc::new(Int64Array::from(vec![0, 0, 0, 0, 1, 1, 1, 1, 2, 2])) as ArrayRef,
            Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2, 2, 3, 3, 4, 4])) as ArrayRef,
        ];

        let (merged_vals, merged_ts) = merge_ordered_arrays(
            &mut [lhs_vals, rhs_vals],
            &mut [lhs_orderings, rhs_orderings],
            &sort_options,
        )?;
        let merged_vals = ScalarValue::iter_to_array(merged_vals.into_iter())?;
        // Transpose merged row-major ordering values back into columns.
        let merged_ts = (0..merged_ts[0].len())
            .map(|col_idx| {
                ScalarValue::iter_to_array(
                    (0..merged_ts.len())
                        .map(|row_idx| merged_ts[row_idx][col_idx].clone()),
                )
            })
            .collect::<Result<Vec<_>>>()?;

        assert_eq!(&merged_vals, &expected);
        assert_eq!(&merged_ts, &expected_ts);

        Ok(())
    }

    // Same merge but with descending sort options and descending inputs.
    #[test]
    fn test_merge_desc() -> Result<()> {
        let lhs_arrays: Vec<ArrayRef> = vec![
            Arc::new(Int64Array::from(vec![2, 1, 1, 0, 0])),
            Arc::new(Int64Array::from(vec![4, 3, 2, 1, 0])),
        ];
        let n_row = lhs_arrays[0].len();
        let lhs_orderings = (0..n_row)
            .map(|idx| get_row_at_idx(&lhs_arrays, idx))
            .collect::<Result<VecDeque<_>>>()?;

        let rhs_arrays: Vec<ArrayRef> = vec![
            Arc::new(Int64Array::from(vec![2, 1, 1, 0, 0])),
            Arc::new(Int64Array::from(vec![4, 3, 2, 1, 0])),
        ];
        let n_row = rhs_arrays[0].len();
        let rhs_orderings = (0..n_row)
            .map(|idx| get_row_at_idx(&rhs_arrays, idx))
            .collect::<Result<VecDeque<_>>>()?;
        let sort_options = vec![
            SortOptions {
                descending: true,
                nulls_first: false,
            },
            SortOptions {
                descending: true,
                nulls_first: false,
            },
        ];

        // Values (which will be merged) doesn't have to be ordered.
        let lhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 1, 2])) as ArrayRef;
        let lhs_vals = (0..lhs_vals_arr.len())
            .map(|idx| ScalarValue::try_from_array(&lhs_vals_arr, idx))
            .collect::<Result<VecDeque<_>>>()?;

        let rhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 1, 2])) as ArrayRef;
        let rhs_vals = (0..rhs_vals_arr.len())
            .map(|idx| ScalarValue::try_from_array(&rhs_vals_arr, idx))
            .collect::<Result<VecDeque<_>>>()?;
        let expected =
            Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2, 2, 1, 1, 2, 2])) as ArrayRef;
        let expected_ts = vec![
            Arc::new(Int64Array::from(vec![2, 2, 1, 1, 1, 1, 0, 0, 0, 0])) as ArrayRef,
            Arc::new(Int64Array::from(vec![4, 4, 3, 3, 2, 2, 1, 1, 0, 0])) as ArrayRef,
        ];
        let (merged_vals, merged_ts) = merge_ordered_arrays(
            &mut [lhs_vals, rhs_vals],
            &mut [lhs_orderings, rhs_orderings],
            &sort_options,
        )?;
        let merged_vals = ScalarValue::iter_to_array(merged_vals.into_iter())?;
        // Transpose merged row-major ordering values back into columns.
        let merged_ts = (0..merged_ts[0].len())
            .map(|col_idx| {
                ScalarValue::iter_to_array(
                    (0..merged_ts.len())
                        .map(|row_idx| merged_ts[row_idx][col_idx].clone()),
                )
            })
            .collect::<Result<Vec<_>>>()?;

        assert_eq!(&merged_vals, &expected);
        assert_eq!(&merged_ts, &expected_ts);
        Ok(())
    }
}