/Users/andrewlamb/Software/datafusion/datafusion/common/src/utils/mod.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! This module provides common utility functions, including the `bisect` binary search helper. |
19 | | |
20 | | pub mod expr; |
21 | | pub mod memory; |
22 | | pub mod proxy; |
23 | | pub mod string_utils; |
24 | | |
25 | | use crate::error::{_internal_datafusion_err, _internal_err}; |
26 | | use crate::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; |
27 | | use arrow::array::{ArrayRef, PrimitiveArray}; |
28 | | use arrow::buffer::OffsetBuffer; |
29 | | use arrow::compute; |
30 | | use arrow::compute::{partition, SortColumn, SortOptions}; |
31 | | use arrow::datatypes::{Field, SchemaRef, UInt32Type}; |
32 | | use arrow::record_batch::RecordBatch; |
33 | | use arrow_array::cast::AsArray; |
34 | | use arrow_array::{ |
35 | | Array, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait, |
36 | | RecordBatchOptions, |
37 | | }; |
38 | | use arrow_schema::DataType; |
39 | | use sqlparser::ast::Ident; |
40 | | use sqlparser::dialect::GenericDialect; |
41 | | use sqlparser::parser::Parser; |
42 | | use std::borrow::{Borrow, Cow}; |
43 | | use std::cmp::{min, Ordering}; |
44 | | use std::collections::HashSet; |
45 | | use std::ops::Range; |
46 | | use std::sync::Arc; |
47 | | |
48 | | /// Applies an optional projection to a [`SchemaRef`], returning the |
49 | | /// projected schema |
50 | | /// |
51 | | /// Example: |
52 | | /// ``` |
53 | | /// use arrow::datatypes::{SchemaRef, Schema, Field, DataType}; |
54 | | /// use datafusion_common::project_schema; |
55 | | /// |
56 | | /// // Schema with columns 'a', 'b', and 'c' |
57 | | /// let schema = SchemaRef::new(Schema::new(vec![ |
58 | | /// Field::new("a", DataType::Int32, true), |
59 | | /// Field::new("b", DataType::Int64, true), |
60 | | /// Field::new("c", DataType::Utf8, true), |
61 | | /// ])); |
62 | | /// |
63 | | /// // Pick columns 'c' and 'b' |
64 | | /// let projection = Some(vec![2,1]); |
65 | | /// let projected_schema = project_schema( |
66 | | /// &schema, |
67 | | /// projection.as_ref() |
68 | | /// ).unwrap(); |
69 | | /// |
70 | | /// let expected_schema = SchemaRef::new(Schema::new(vec![ |
71 | | /// Field::new("c", DataType::Utf8, true), |
72 | | /// Field::new("b", DataType::Int64, true), |
73 | | /// ])); |
74 | | /// |
75 | | /// assert_eq!(projected_schema, expected_schema); |
76 | | /// ``` |
77 | 1.17k | pub fn project_schema( |
78 | 1.17k | schema: &SchemaRef, |
79 | 1.17k | projection: Option<&Vec<usize>>, |
80 | 1.17k | ) -> Result<SchemaRef> { |
81 | 1.17k | let schema = match projection { |
82 | 0 | Some(columns) => Arc::new(schema.project(columns)?), |
83 | 1.17k | None => Arc::clone(schema), |
84 | | }; |
85 | 1.17k | Ok(schema) |
86 | 1.17k | } |
87 | | |
88 | | /// Given column vectors, returns row at `idx`. |
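 | | /// |
 | | /// A minimal usage sketch (the column values below are arbitrary): |
 | | /// ``` |
 | | /// use std::sync::Arc; |
 | | /// use arrow::array::{ArrayRef, Int32Array, StringArray}; |
 | | /// use datafusion_common::ScalarValue; |
 | | /// use datafusion_common::utils::get_row_at_idx; |
 | | /// |
 | | /// let columns: Vec<ArrayRef> = vec![ |
 | | ///     Arc::new(Int32Array::from(vec![1, 2, 3])), |
 | | ///     Arc::new(StringArray::from(vec!["a", "b", "c"])), |
 | | /// ]; |
 | | /// // Row 1 is (2, "b") |
 | | /// let row = get_row_at_idx(&columns, 1).unwrap(); |
 | | /// assert_eq!(row, vec![ |
 | | ///     ScalarValue::Int32(Some(2)), |
 | | ///     ScalarValue::Utf8(Some("b".to_string())), |
 | | /// ]); |
 | | /// ``` |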
89 | 6.21k | pub fn get_row_at_idx(columns: &[ArrayRef], idx: usize) -> Result<Vec<ScalarValue>> { |
90 | 6.21k | columns |
91 | 6.21k | .iter() |
92 | 6.43k | .map(|arr| ScalarValue::try_from_array(arr, idx)) |
93 | 6.21k | .collect() |
94 | 6.21k | } |
95 | | |
96 | | /// Construct a new RecordBatch from the rows of the `record_batch` at the `indices`. |
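 | | /// |
 | | /// A small illustrative example (the batch below is arbitrary): |
 | | /// ``` |
 | | /// use std::sync::Arc; |
 | | /// use arrow::array::{ArrayRef, Int32Array, UInt32Array}; |
 | | /// use arrow::datatypes::{DataType, Field, Schema}; |
 | | /// use arrow::record_batch::RecordBatch; |
 | | /// use datafusion_common::utils::get_record_batch_at_indices; |
 | | /// |
 | | /// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); |
 | | /// let column: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30])); |
 | | /// let batch = RecordBatch::try_new(schema, vec![column]).unwrap(); |
 | | /// |
 | | /// // Keep rows 2 and 0, in that order |
 | | /// let indices = UInt32Array::from(vec![2, 0]); |
 | | /// let taken = get_record_batch_at_indices(&batch, &indices).unwrap(); |
 | | /// assert_eq!(taken.num_rows(), 2); |
 | | /// ``` |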
97 | 5 | pub fn get_record_batch_at_indices( |
98 | 5 | record_batch: &RecordBatch, |
99 | 5 | indices: &PrimitiveArray<UInt32Type>, |
100 | 5 | ) -> Result<RecordBatch> { |
101 | 5 | let new_columns = take_arrays(record_batch.columns(), indices)?; |
102 | 5 | RecordBatch::try_new_with_options( |
103 | 5 | record_batch.schema(), |
104 | 5 | new_columns, |
105 | 5 | &RecordBatchOptions::new().with_row_count(Some(indices.len())), |
106 | 5 | ) |
107 | 5 | .map_err(|e| arrow_datafusion_err!(e)) |
108 | 5 | } |
109 | | |
110 | | /// This function compares two tuples depending on the given sort options. |
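 | | /// |
 | | /// For example, with two ascending columns (values arbitrary): |
 | | /// ``` |
 | | /// use std::cmp::Ordering; |
 | | /// use arrow::compute::SortOptions; |
 | | /// use datafusion_common::ScalarValue; |
 | | /// use datafusion_common::utils::compare_rows; |
 | | /// |
 | | /// let x = vec![ScalarValue::Int32(Some(1)), ScalarValue::Int32(Some(5))]; |
 | | /// let y = vec![ScalarValue::Int32(Some(1)), ScalarValue::Int32(Some(3))]; |
 | | /// // Both columns ascending with nulls first |
 | | /// let options = vec![SortOptions::default(); 2]; |
 | | /// assert_eq!(compare_rows(&x, &y, &options).unwrap(), Ordering::Greater); |
 | | /// ``` |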
111 | 6.02k | pub fn compare_rows( |
112 | 6.02k | x: &[ScalarValue], |
113 | 6.02k | y: &[ScalarValue], |
114 | 6.02k | sort_options: &[SortOptions], |
115 | 6.02k | ) -> Result<Ordering> { |
116 | 6.02k | let zip_it = x.iter().zip(y.iter()).zip(sort_options.iter()); |
117 | | // Preserving lexical ordering. |
118 | 6.18k | for ((lhs, rhs), sort_options) in zip_it { |
119 | | // Consider all combinations of NULLS FIRST/LAST and ASC/DESC configurations. |
120 | 6.02k | let result = match (lhs.is_null(), rhs.is_null(), sort_options.nulls_first) { |
121 | 0 | (true, false, false) | (false, true, true) => Ordering::Greater, |
122 | 18 | (true, false, true) | (false, true, false) => Ordering::Less, |
123 | 6.01k | (false, false, _) => if sort_options.descending { |
124 | 1.47k | rhs.partial_cmp(lhs) |
125 | | } else { |
126 | 4.53k | lhs.partial_cmp(rhs) |
127 | | } |
128 | 6.01k | .ok_or_else(|| { |
129 | 0 | _internal_datafusion_err!("Column array shouldn't be empty") |
130 | 6.01k | })?, |
131 | 0 | (true, true, _) => continue, |
132 | | }; |
133 | 6.02k | if result != Ordering::Equal { |
134 | 5.87k | return Ok(result); |
135 | 154 | } |
136 | | } |
137 | 154 | Ok(Ordering::Equal) |
138 | 6.02k | } |
139 | | |
140 | | /// This function searches for a tuple of given values (`target`) among the given |
141 | | /// rows (`item_columns`) using the bisection algorithm. It assumes that `item_columns` |
142 | | /// is sorted according to `sort_options` and returns the insertion index of `target`. |
143 | | /// The const generic argument `SIDE` being `true`/`false` means left/right insertion. |
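 | | /// |
 | | /// A minimal usage sketch over a single ascending column (values arbitrary): |
 | | /// ``` |
 | | /// use std::sync::Arc; |
 | | /// use arrow::array::{ArrayRef, Int32Array}; |
 | | /// use arrow::compute::SortOptions; |
 | | /// use datafusion_common::ScalarValue; |
 | | /// use datafusion_common::utils::bisect; |
 | | /// |
 | | /// let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![1, 3, 5, 7]))]; |
 | | /// let target = vec![ScalarValue::Int32(Some(5))]; |
 | | /// let options = vec![SortOptions::default()]; |
 | | /// // Left insertion point of 5 is index 2, right insertion point is index 3 |
 | | /// assert_eq!(bisect::<true>(&columns, &target, &options).unwrap(), 2); |
 | | /// assert_eq!(bisect::<false>(&columns, &target, &options).unwrap(), 3); |
 | | /// ``` |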
144 | 3.66k | pub fn bisect<const SIDE: bool>( |
145 | 3.66k | item_columns: &[ArrayRef], |
146 | 3.66k | target: &[ScalarValue], |
147 | 3.66k | sort_options: &[SortOptions], |
148 | 3.66k | ) -> Result<usize> { |
149 | 3.66k | let low: usize = 0; |
150 | 3.66k | let high: usize = item_columns |
151 | 3.66k | .first() |
152 | 3.66k | .ok_or_else(|| { |
153 | 0 | DataFusionError::Internal("Column array shouldn't be empty".to_string()) |
154 | 3.66k | })? |
155 | 3.66k | .len(); |
156 | 5.95k | let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| { |
157 | 5.95k | let cmp = compare_rows(current, target, sort_options)?; |
158 | 5.95k | Ok(if SIDE { cmp.is_lt() } else { cmp.is_le() }) |
159 | 5.95k | }; |
160 | 3.66k | find_bisect_point(item_columns, target, compare_fn, low, high) |
161 | 3.66k | } |
162 | | |
163 | | /// This function searches for a tuple of given values (`target`) among a slice of |
164 | | /// the given rows (`item_columns`) using the bisection algorithm. The slice starts |
165 | | /// at the index `low` and ends at the index `high`. The boolean-valued function |
166 | | /// `compare_fn` specifies whether we bisect on the left (by returning `false`), |
167 | | /// or on the right (by returning `true`) when we compare the target value with |
168 | | /// the current value as we iteratively bisect the input. |
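 | | /// |
 | | /// A minimal sketch with a hand-written `compare_fn` (inputs are arbitrary): |
 | | /// ``` |
 | | /// use std::sync::Arc; |
 | | /// use arrow::array::{ArrayRef, Int32Array}; |
 | | /// use datafusion_common::ScalarValue; |
 | | /// use datafusion_common::utils::find_bisect_point; |
 | | /// |
 | | /// let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![1, 3, 5, 7]))]; |
 | | /// let target = vec![ScalarValue::Int32(Some(5))]; |
 | | /// // Bisect on the left: keep going right while the current row is strictly |
 | | /// // smaller than the target. |
 | | /// let point = find_bisect_point( |
 | | ///     &columns, |
 | | ///     &target, |
 | | ///     |current, target| Ok(current[0] < target[0]), |
 | | ///     0, |
 | | ///     4, |
 | | /// ) |
 | | /// .unwrap(); |
 | | /// assert_eq!(point, 2); |
 | | /// ``` |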
169 | 3.66k | pub fn find_bisect_point<F>( |
170 | 3.66k | item_columns: &[ArrayRef], |
171 | 3.66k | target: &[ScalarValue], |
172 | 3.66k | compare_fn: F, |
173 | 3.66k | mut low: usize, |
174 | 3.66k | mut high: usize, |
175 | 3.66k | ) -> Result<usize> |
176 | 3.66k | where |
177 | 3.66k | F: Fn(&[ScalarValue], &[ScalarValue]) -> Result<bool>, |
178 | 3.66k | { |
179 | 9.61k | while low < high { |
180 | 5.95k | let mid = ((high - low) / 2) + low; |
181 | 5.95k | let val = get_row_at_idx(item_columns, mid)?; |
182 | 5.95k | if compare_fn(&val, target)? { |
183 | 5.16k | low = mid + 1; |
184 | 5.16k | } else { |
185 | 786 | high = mid; |
186 | 786 | } |
187 | | } |
188 | 3.66k | Ok(low) |
189 | 3.66k | } |
190 | | |
191 | | /// This function searches for a tuple of given values (`target`) among the given |
192 | | /// rows (`item_columns`) via a linear scan. It assumes that `item_columns` is sorted |
193 | | /// according to `sort_options` and returns the insertion index of `target`. |
194 | | /// The const generic argument `SIDE` being `true`/`false` means left/right insertion. |
195 | 0 | pub fn linear_search<const SIDE: bool>( |
196 | 0 | item_columns: &[ArrayRef], |
197 | 0 | target: &[ScalarValue], |
198 | 0 | sort_options: &[SortOptions], |
199 | 0 | ) -> Result<usize> { |
200 | 0 | let low: usize = 0; |
201 | 0 | let high: usize = item_columns |
202 | 0 | .first() |
203 | 0 | .ok_or_else(|| { |
204 | 0 | DataFusionError::Internal("Column array shouldn't be empty".to_string()) |
205 | 0 | })? |
206 | 0 | .len(); |
207 | 0 | let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| { |
208 | 0 | let cmp = compare_rows(current, target, sort_options)?; |
209 | 0 | Ok(if SIDE { cmp.is_lt() } else { cmp.is_le() }) |
210 | 0 | }; |
211 | 0 | search_in_slice(item_columns, target, compare_fn, low, high) |
212 | 0 | } |
213 | | |
214 | | /// This function searches for a tuple of given values (`target`) among a slice of |
215 | | /// the given rows (`item_columns`) via a linear scan. The slice starts at the index |
216 | | /// `low` and ends at the index `high`. The boolean-valued function `compare_fn` |
217 | | /// specifies the stopping criterion. |
218 | 26 | pub fn search_in_slice<F>( |
219 | 26 | item_columns: &[ArrayRef], |
220 | 26 | target: &[ScalarValue], |
221 | 26 | compare_fn: F, |
222 | 26 | mut low: usize, |
223 | 26 | high: usize, |
224 | 26 | ) -> Result<usize> |
225 | 26 | where |
226 | 26 | F: Fn(&[ScalarValue], &[ScalarValue]) -> Result<bool>, |
227 | 26 | { |
228 | 50 | while low < high { |
229 | 41 | let val = get_row_at_idx(item_columns, low)?; |
230 | 41 | if !compare_fn(&val, target)? { |
231 | 17 | break; |
232 | 24 | } |
233 | 24 | low += 1; |
234 | | } |
235 | 26 | Ok(low) |
236 | 26 | } |
237 | | |
238 | | /// Given a list of 0 or more already sorted columns, finds the ranges of |
239 | | /// consecutive rows in which all of the partition columns have equal values. |
240 | | /// |
241 | | /// See [`partition`] for more details. |
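 | | /// |
 | | /// An illustrative example with one pre-sorted partition column: |
 | | /// ``` |
 | | /// use std::sync::Arc; |
 | | /// use arrow::array::{ArrayRef, Int32Array}; |
 | | /// use arrow::compute::SortColumn; |
 | | /// use datafusion_common::utils::evaluate_partition_ranges; |
 | | /// |
 | | /// let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 1, 2, 2, 2])); |
 | | /// let column = SortColumn { values, options: None }; |
 | | /// let ranges = evaluate_partition_ranges(5, &[column]).unwrap(); |
 | | /// assert_eq!(ranges, vec![(0..2), (2..5)]); |
 | | /// ``` |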
242 | 24 | pub fn evaluate_partition_ranges( |
243 | 24 | num_rows: usize, |
244 | 24 | partition_columns: &[SortColumn], |
245 | 24 | ) -> Result<Vec<Range<usize>>> { |
246 | 24 | Ok(if partition_columns.is_empty() { |
247 | 3 | vec![Range { |
248 | 3 | start: 0, |
249 | 3 | end: num_rows, |
250 | 3 | }] |
251 | | } else { |
252 | 21 | let cols: Vec<_> = partition_columns |
253 | 21 | .iter() |
254 | 25 | .map(|x| Arc::clone(&x.values)) |
255 | 21 | .collect(); |
256 | 21 | partition(&cols)?.ranges() |
257 | | }) |
258 | 24 | } |
259 | | |
260 | | /// Wraps an identifier string in double quotes, escaping any double quotes in |
261 | | /// the identifier by replacing each with two double quotes. |
262 | | /// |
263 | | /// e.g. identifier `tab.le"name` becomes `"tab.le""name"` |
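 | | /// |
 | | /// A couple of illustrative cases: |
 | | /// ``` |
 | | /// use datafusion_common::utils::quote_identifier; |
 | | /// |
 | | /// // lowercase identifiers do not need quoting |
 | | /// assert_eq!(quote_identifier("foo_bar"), "foo_bar"); |
 | | /// // anything else is wrapped in quotes, with embedded quotes doubled |
 | | /// assert_eq!(quote_identifier("tab.le\"name"), "\"tab.le\"\"name\""); |
 | | /// ``` |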
264 | 0 | pub fn quote_identifier(s: &str) -> Cow<str> { |
265 | 0 | if needs_quotes(s) { |
266 | 0 | Cow::Owned(format!("\"{}\"", s.replace('"', "\"\""))) |
267 | | } else { |
268 | 0 | Cow::Borrowed(s) |
269 | | } |
270 | 0 | } |
271 | | |
272 | | /// returns true if this identifier needs quotes |
273 | 0 | fn needs_quotes(s: &str) -> bool { |
274 | 0 | let mut chars = s.chars(); |
275 | | |
276 | | // quoting is needed unless the first character is a lowercase ASCII letter or underscore |
277 | 0 | if let Some(first_char) = chars.next() { |
278 | 0 | if !(first_char.is_ascii_lowercase() || first_char == '_') { |
279 | 0 | return true; |
280 | 0 | } |
281 | 0 | } |
282 | | |
283 | 0 | !chars.all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_') |
284 | 0 | } |
285 | | |
286 | 0 | pub(crate) fn parse_identifiers(s: &str) -> Result<Vec<Ident>> { |
287 | 0 | let dialect = GenericDialect; |
288 | 0 | let mut parser = Parser::new(&dialect).try_with_sql(s)?; |
289 | 0 | let idents = parser.parse_multipart_identifier()?; |
290 | 0 | Ok(idents) |
291 | 0 | } |
292 | | |
293 | | /// Construct a new [`Vec`] of [`ArrayRef`] from the rows of the `arrays` at the `indices`. |
294 | | /// |
295 | | /// TODO: use implementation in arrow-rs when available: |
296 | | /// <https://github.com/apache/arrow-rs/pull/6475> |
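 | | /// |
 | | /// A minimal usage sketch (the array contents are arbitrary): |
 | | /// ``` |
 | | /// use std::sync::Arc; |
 | | /// use arrow::array::{Array, ArrayRef, Int32Array, UInt32Array}; |
 | | /// use datafusion_common::utils::take_arrays; |
 | | /// |
 | | /// let arrays: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![10, 20, 30]))]; |
 | | /// // Keep rows 2 and 0, in that order |
 | | /// let indices = UInt32Array::from(vec![2, 0]); |
 | | /// let taken = take_arrays(&arrays, &indices).unwrap(); |
 | | /// |
 | | /// let column = taken[0].as_any().downcast_ref::<Int32Array>().unwrap(); |
 | | /// assert_eq!(column.value(0), 30); |
 | | /// assert_eq!(column.value(1), 10); |
 | | /// ``` |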
297 | 17.1k | pub fn take_arrays(arrays: &[ArrayRef], indices: &dyn Array) -> Result<Vec<ArrayRef>> { |
298 | 17.1k | arrays |
299 | 17.1k | .iter() |
300 | 179k | .map(|array| { |
301 | 179k | compute::take( |
302 | 179k | array.as_ref(), |
303 | 179k | indices, |
304 | 179k | None, // None: no index check |
305 | 179k | ) |
306 | 179k | .map_err(|e| arrow_datafusion_err!(e)) |
307 | 179k | }) |
308 | 17.1k | .collect() |
309 | 17.1k | } |
310 | | |
311 | 0 | pub(crate) fn parse_identifiers_normalized(s: &str, ignore_case: bool) -> Vec<String> { |
312 | 0 | parse_identifiers(s) |
313 | 0 | .unwrap_or_default() |
314 | 0 | .into_iter() |
315 | 0 | .map(|id| match id.quote_style { |
316 | 0 | Some(_) => id.value, |
317 | 0 | None if ignore_case => id.value, |
318 | 0 | _ => id.value.to_ascii_lowercase(), |
319 | 0 | }) |
320 | 0 | .collect::<Vec<_>>() |
321 | 0 | } |
322 | | |
323 | | /// This function "takes" the elements at `indices` from the slice `items`. |
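 | | /// |
 | | /// A minimal usage sketch: |
 | | /// ``` |
 | | /// use datafusion_common::utils::get_at_indices; |
 | | /// |
 | | /// let items = vec!["a", "b", "c", "d"]; |
 | | /// assert_eq!(get_at_indices(&items, [2, 0]).unwrap(), vec!["c", "a"]); |
 | | /// // Out-of-range indices produce an error |
 | | /// assert!(get_at_indices(&items, [4]).is_err()); |
 | | /// ``` |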
324 | 3 | pub fn get_at_indices<T: Clone, I: Borrow<usize>>( |
325 | 3 | items: &[T], |
326 | 3 | indices: impl IntoIterator<Item = I>, |
327 | 3 | ) -> Result<Vec<T>> { |
328 | 3 | indices |
329 | 3 | .into_iter() |
330 | 3 | .map(|idx| items.get(*idx.borrow()).cloned()) |
331 | 3 | .collect::<Option<Vec<T>>>() |
332 | 3 | .ok_or_else(|| { |
333 | 0 | DataFusionError::Execution( |
334 | 0 | "Expects indices to be in the range of searched vector".to_string(), |
335 | 0 | ) |
336 | 3 | }) |
337 | 3 | } |
338 | | |
339 | | /// This function finds the longest prefix of the form 0, 1, 2, ... within the |
340 | | /// collection `sequence`. Examples: |
341 | | /// - For 0, 1, 2, 4, 5; we would produce 3, meaning 0, 1, 2 is the longest satisfying |
342 | | /// prefix. |
343 | | /// - For 1, 2, 3, 4; we would produce 0, meaning there is no such prefix. |
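 | | /// |
 | | /// The same examples, as code: |
 | | /// ``` |
 | | /// use datafusion_common::utils::longest_consecutive_prefix; |
 | | /// |
 | | /// assert_eq!(longest_consecutive_prefix([0, 1, 2, 4, 5]), 3); |
 | | /// assert_eq!(longest_consecutive_prefix([1, 2, 3, 4]), 0); |
 | | /// ``` |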
344 | 0 | pub fn longest_consecutive_prefix<T: Borrow<usize>>( |
345 | 0 | sequence: impl IntoIterator<Item = T>, |
346 | 0 | ) -> usize { |
347 | 0 | let mut count = 0; |
348 | 0 | for item in sequence { |
349 | 0 | if !count.eq(item.borrow()) { |
350 | 0 | break; |
351 | 0 | } |
352 | 0 | count += 1; |
353 | | } |
354 | 0 | count |
355 | 0 | } |
356 | | |
357 | | /// Array Utils |
358 | | |
359 | | /// Wrap an array into a single element `ListArray`. |
360 | | /// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` |
361 | | /// The field in the list array is nullable. |
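 | | /// |
 | | /// An illustrative sketch (the input array is arbitrary): |
 | | /// ``` |
 | | /// use std::sync::Arc; |
 | | /// use arrow::array::{Array, ArrayRef, Int32Array}; |
 | | /// use datafusion_common::utils::array_into_list_array_nullable; |
 | | /// |
 | | /// let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); |
 | | /// let list = array_into_list_array_nullable(arr); |
 | | /// // A single list element containing all three values |
 | | /// assert_eq!(list.len(), 1); |
 | | /// assert_eq!(list.value(0).len(), 3); |
 | | /// ``` |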
362 | 0 | pub fn array_into_list_array_nullable(arr: ArrayRef) -> ListArray { |
363 | 0 | array_into_list_array(arr, true) |
364 | 0 | } |
365 | | |
366 | | /// Array Utils |
367 | | |
368 | | /// Wrap an array into a single element `ListArray`. |
369 | | /// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` |
370 | 0 | pub fn array_into_list_array(arr: ArrayRef, nullable: bool) -> ListArray { |
371 | 0 | let offsets = OffsetBuffer::from_lengths([arr.len()]); |
372 | 0 | ListArray::new( |
373 | 0 | Arc::new(Field::new_list_field(arr.data_type().to_owned(), nullable)), |
374 | 0 | offsets, |
375 | 0 | arr, |
376 | 0 | None, |
377 | 0 | ) |
378 | 0 | } |
379 | | |
380 | | /// Wrap an array into a single element `LargeListArray`. |
381 | | /// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` |
382 | 0 | pub fn array_into_large_list_array(arr: ArrayRef) -> LargeListArray { |
383 | 0 | let offsets = OffsetBuffer::from_lengths([arr.len()]); |
384 | 0 | LargeListArray::new( |
385 | 0 | Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)), |
386 | 0 | offsets, |
387 | 0 | arr, |
388 | 0 | None, |
389 | 0 | ) |
390 | 0 | } |
391 | | |
392 | 0 | pub fn array_into_fixed_size_list_array( |
393 | 0 | arr: ArrayRef, |
394 | 0 | list_size: usize, |
395 | 0 | ) -> FixedSizeListArray { |
396 | 0 | let list_size = list_size as i32; |
397 | 0 | FixedSizeListArray::new( |
398 | 0 | Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)), |
399 | 0 | list_size, |
400 | 0 | arr, |
401 | 0 | None, |
402 | 0 | ) |
403 | 0 | } |
404 | | |
405 | | /// Wrap arrays into a single element `ListArray`. |
406 | | /// |
407 | | /// Example: |
408 | | /// ``` |
409 | | /// use arrow::array::{Int32Array, ListArray, ArrayRef}; |
410 | | /// use arrow::datatypes::{Int32Type, Field}; |
411 | | /// use std::sync::Arc; |
412 | | /// |
413 | | /// let arr1 = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; |
414 | | /// let arr2 = Arc::new(Int32Array::from(vec![4, 5, 6])) as ArrayRef; |
415 | | /// |
416 | | /// let list_arr = datafusion_common::utils::arrays_into_list_array([arr1, arr2]).unwrap(); |
417 | | /// |
418 | | /// let expected = ListArray::from_iter_primitive::<Int32Type, _, _>( |
419 | | /// vec![ |
420 | | /// Some(vec![Some(1), Some(2), Some(3)]), |
421 | | /// Some(vec![Some(4), Some(5), Some(6)]), |
422 | | /// ] |
423 | | /// ); |
424 | | /// |
425 | | /// assert_eq!(list_arr, expected); |
 | | /// ``` |
426 | 0 | pub fn arrays_into_list_array( |
427 | 0 | arr: impl IntoIterator<Item = ArrayRef>, |
428 | 0 | ) -> Result<ListArray> { |
429 | 0 | let arr = arr.into_iter().collect::<Vec<_>>(); |
430 | 0 | if arr.is_empty() { |
431 | 0 | return _internal_err!("Cannot wrap empty array into list array"); |
432 | 0 | } |
433 | 0 |
|
434 | 0 | let lens = arr.iter().map(|x| x.len()).collect::<Vec<_>>(); |
435 | 0 | // Assume data type is consistent |
436 | 0 | let data_type = arr[0].data_type().to_owned(); |
437 | 0 | let values = arr.iter().map(|x| x.as_ref()).collect::<Vec<_>>(); |
438 | 0 | Ok(ListArray::new( |
439 | 0 | Arc::new(Field::new_list_field(data_type, true)), |
440 | 0 | OffsetBuffer::from_lengths(lens), |
441 | 0 | arrow::compute::concat(values.as_slice())?, |
442 | 0 | None, |
443 | | )) |
444 | 0 | } |
445 | | |
446 | | /// Helper function to convert a ListArray into a vector of ArrayRefs. |
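 | | /// |
 | | /// A minimal usage sketch (a two-element list array): |
 | | /// ``` |
 | | /// use std::sync::Arc; |
 | | /// use arrow::array::{Array, ArrayRef, ListArray}; |
 | | /// use arrow::datatypes::Int32Type; |
 | | /// use datafusion_common::utils::list_to_arrays; |
 | | /// |
 | | /// let list: ArrayRef = Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(vec![ |
 | | ///     Some(vec![Some(1), Some(2)]), |
 | | ///     Some(vec![Some(3)]), |
 | | /// ])); |
 | | /// let arrays = list_to_arrays::<i32>(&list); |
 | | /// assert_eq!(arrays.len(), 2); |
 | | /// assert_eq!(arrays[0].len(), 2); |
 | | /// assert_eq!(arrays[1].len(), 1); |
 | | /// ``` |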
447 | 0 | pub fn list_to_arrays<O: OffsetSizeTrait>(a: &ArrayRef) -> Vec<ArrayRef> { |
448 | 0 | a.as_list::<O>().iter().flatten().collect::<Vec<_>>() |
449 | 0 | } |
450 | | |
451 | | /// Helper function to convert a FixedSizeListArray into a vector of ArrayRefs. |
452 | 0 | pub fn fixed_size_list_to_arrays(a: &ArrayRef) -> Vec<ArrayRef> { |
453 | 0 | a.as_fixed_size_list().iter().flatten().collect::<Vec<_>>() |
454 | 0 | } |
455 | | |
456 | | /// Get the base type of a data type. |
457 | | /// |
458 | | /// Example |
459 | | /// ``` |
460 | | /// use arrow::datatypes::{DataType, Field}; |
461 | | /// use datafusion_common::utils::base_type; |
462 | | /// use std::sync::Arc; |
463 | | /// |
464 | | /// let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); |
465 | | /// assert_eq!(base_type(&data_type), DataType::Int32); |
466 | | /// |
467 | | /// let data_type = DataType::Int32; |
468 | | /// assert_eq!(base_type(&data_type), DataType::Int32); |
469 | | /// ``` |
470 | 0 | pub fn base_type(data_type: &DataType) -> DataType { |
471 | 0 | match data_type { |
472 | 0 | DataType::List(field) |
473 | 0 | | DataType::LargeList(field) |
474 | 0 | | DataType::FixedSizeList(field, _) => base_type(field.data_type()), |
475 | 0 | _ => data_type.to_owned(), |
476 | | } |
477 | 0 | } |
478 | | |
479 | | /// A helper function to coerce base type in List. |
480 | | /// |
481 | | /// Example |
482 | | /// ``` |
483 | | /// use arrow::datatypes::{DataType, Field}; |
484 | | /// use datafusion_common::utils::coerced_type_with_base_type_only; |
485 | | /// use std::sync::Arc; |
486 | | /// |
487 | | /// let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); |
488 | | /// let base_type = DataType::Float64; |
489 | | /// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type); |
490 | | /// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new_list_field(DataType::Float64, true)))); |
 | | /// ``` |
491 | 0 | pub fn coerced_type_with_base_type_only( |
492 | 0 | data_type: &DataType, |
493 | 0 | base_type: &DataType, |
494 | 0 | ) -> DataType { |
495 | 0 | match data_type { |
496 | 0 | DataType::List(field) | DataType::FixedSizeList(field, _) => { |
497 | 0 | let field_type = |
498 | 0 | coerced_type_with_base_type_only(field.data_type(), base_type); |
499 | 0 | |
500 | 0 | DataType::List(Arc::new(Field::new( |
501 | 0 | field.name(), |
502 | 0 | field_type, |
503 | 0 | field.is_nullable(), |
504 | 0 | ))) |
505 | | } |
506 | 0 | DataType::LargeList(field) => { |
507 | 0 | let field_type = |
508 | 0 | coerced_type_with_base_type_only(field.data_type(), base_type); |
509 | 0 | |
510 | 0 | DataType::LargeList(Arc::new(Field::new( |
511 | 0 | field.name(), |
512 | 0 | field_type, |
513 | 0 | field.is_nullable(), |
514 | 0 | ))) |
515 | | } |
516 | | |
517 | 0 | _ => base_type.clone(), |
518 | | } |
519 | 0 | } |
520 | | |
521 | | /// Recursively coerces `FixedSizeList` elements to `List`. |
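 | | /// |
 | | /// A minimal usage sketch: |
 | | /// ``` |
 | | /// use std::sync::Arc; |
 | | /// use arrow::datatypes::{DataType, Field}; |
 | | /// use datafusion_common::utils::coerced_fixed_size_list_to_list; |
 | | /// |
 | | /// let fixed = DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int32, true)), 3); |
 | | /// let coerced = coerced_fixed_size_list_to_list(&fixed); |
 | | /// assert_eq!(coerced, DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true)))); |
 | | /// ``` |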
522 | 0 | pub fn coerced_fixed_size_list_to_list(data_type: &DataType) -> DataType { |
523 | 0 | match data_type { |
524 | 0 | DataType::List(field) | DataType::FixedSizeList(field, _) => { |
525 | 0 | let field_type = coerced_fixed_size_list_to_list(field.data_type()); |
526 | 0 | |
527 | 0 | DataType::List(Arc::new(Field::new( |
528 | 0 | field.name(), |
529 | 0 | field_type, |
530 | 0 | field.is_nullable(), |
531 | 0 | ))) |
532 | | } |
533 | 0 | DataType::LargeList(field) => { |
534 | 0 | let field_type = coerced_fixed_size_list_to_list(field.data_type()); |
535 | 0 |
|
536 | 0 | DataType::LargeList(Arc::new(Field::new( |
537 | 0 | field.name(), |
538 | 0 | field_type, |
539 | 0 | field.is_nullable(), |
540 | 0 | ))) |
541 | | } |
542 | | |
543 | 0 | _ => data_type.clone(), |
544 | | } |
545 | 0 | } |
546 | | |
547 | | /// Compute the number of dimensions in a list data type. |
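 | | /// |
 | | /// For instance: |
 | | /// ``` |
 | | /// use std::sync::Arc; |
 | | /// use arrow::datatypes::{DataType, Field}; |
 | | /// use datafusion_common::utils::list_ndims; |
 | | /// |
 | | /// let list = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); |
 | | /// let list_of_list = DataType::List(Arc::new(Field::new_list_field(list.clone(), true))); |
 | | /// assert_eq!(list_ndims(&DataType::Int32), 0); |
 | | /// assert_eq!(list_ndims(&list), 1); |
 | | /// assert_eq!(list_ndims(&list_of_list), 2); |
 | | /// ``` |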
548 | 0 | pub fn list_ndims(data_type: &DataType) -> u64 { |
549 | 0 | match data_type { |
550 | 0 | DataType::List(field) |
551 | 0 | | DataType::LargeList(field) |
552 | 0 | | DataType::FixedSizeList(field, _) => 1 + list_ndims(field.data_type()), |
553 | 0 | _ => 0, |
554 | | } |
555 | 0 | } |
556 | | |
557 | | /// Adapted from strsim-rs for string similarity metrics |
558 | | pub mod datafusion_strsim { |
559 | | // Source: https://github.com/dguo/strsim-rs/blob/master/src/lib.rs |
560 | | // License: https://github.com/dguo/strsim-rs/blob/master/LICENSE |
561 | | use std::cmp::min; |
562 | | use std::str::Chars; |
563 | | |
564 | | struct StringWrapper<'a>(&'a str); |
565 | | |
566 | | impl<'a, 'b> IntoIterator for &'a StringWrapper<'b> { |
567 | | type Item = char; |
568 | | type IntoIter = Chars<'b>; |
569 | | |
570 | | fn into_iter(self) -> Self::IntoIter { |
571 | | self.0.chars() |
572 | | } |
573 | | } |
574 | | |
575 | | /// Calculates the minimum number of insertions, deletions, and substitutions |
576 | | /// required to change one sequence into the other. |
577 | | fn generic_levenshtein<'a, 'b, Iter1, Iter2, Elem1, Elem2>( |
578 | | a: &'a Iter1, |
579 | | b: &'b Iter2, |
580 | | ) -> usize |
581 | | where |
582 | | &'a Iter1: IntoIterator<Item = Elem1>, |
583 | | &'b Iter2: IntoIterator<Item = Elem2>, |
584 | | Elem1: PartialEq<Elem2>, |
585 | | { |
586 | | let b_len = b.into_iter().count(); |
587 | | |
588 | | if a.into_iter().next().is_none() { |
589 | | return b_len; |
590 | | } |
591 | | |
592 | | let mut cache: Vec<usize> = (1..b_len + 1).collect(); |
593 | | |
594 | | let mut result = 0; |
595 | | |
596 | | for (i, a_elem) in a.into_iter().enumerate() { |
597 | | result = i + 1; |
598 | | let mut distance_b = i; |
599 | | |
600 | | for (j, b_elem) in b.into_iter().enumerate() { |
601 | | let cost = if a_elem == b_elem { 0usize } else { 1usize }; |
602 | | let distance_a = distance_b + cost; |
603 | | distance_b = cache[j]; |
604 | | result = min(result + 1, min(distance_a, distance_b + 1)); |
605 | | cache[j] = result; |
606 | | } |
607 | | } |
608 | | |
609 | | result |
610 | | } |
611 | | |
612 | | /// Calculates the minimum number of insertions, deletions, and substitutions |
613 | | /// required to change one string into the other. |
614 | | /// |
615 | | /// ``` |
616 | | /// use datafusion_common::utils::datafusion_strsim::levenshtein; |
617 | | /// |
618 | | /// assert_eq!(3, levenshtein("kitten", "sitting")); |
619 | | /// ``` |
620 | | pub fn levenshtein(a: &str, b: &str) -> usize { |
621 | | generic_levenshtein(&StringWrapper(a), &StringWrapper(b)) |
622 | | } |
623 | | } |
624 | | |
625 | | /// Merges collections `first` and `second`, removes duplicates and sorts the |
626 | | /// result, returning it as a [`Vec`]. |
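 | | /// |
 | | /// For example: |
 | | /// ``` |
 | | /// use datafusion_common::utils::merge_and_order_indices; |
 | | /// |
 | | /// assert_eq!(merge_and_order_indices([3, 0, 4], [1, 3]), vec![0, 1, 3, 4]); |
 | | /// ``` |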
627 | 0 | pub fn merge_and_order_indices<T: Borrow<usize>, S: Borrow<usize>>( |
628 | 0 | first: impl IntoIterator<Item = T>, |
629 | 0 | second: impl IntoIterator<Item = S>, |
630 | 0 | ) -> Vec<usize> { |
631 | 0 | let mut result: Vec<_> = first |
632 | 0 | .into_iter() |
633 | 0 | .map(|e| *e.borrow()) |
634 | 0 | .chain(second.into_iter().map(|e| *e.borrow())) |
635 | 0 | .collect::<HashSet<_>>() |
636 | 0 | .into_iter() |
637 | 0 | .collect(); |
638 | 0 | result.sort(); |
639 | 0 | result |
640 | 0 | } |
641 | | |
642 | | /// Calculates the set difference between sequences `first` and `second`, |
643 | | /// returning the result as a [`Vec`]. Preserves the ordering of `first`. |
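 | | /// |
 | | /// For example: |
 | | /// ``` |
 | | /// use datafusion_common::utils::set_difference; |
 | | /// |
 | | /// // The ordering of the first argument is preserved |
 | | /// assert_eq!(set_difference([3, 0, 4], [4, 1]), vec![3, 0]); |
 | | /// ``` |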
644 | 0 | pub fn set_difference<T: Borrow<usize>, S: Borrow<usize>>( |
645 | 0 | first: impl IntoIterator<Item = T>, |
646 | 0 | second: impl IntoIterator<Item = S>, |
647 | 0 | ) -> Vec<usize> { |
648 | 0 | let set: HashSet<_> = second.into_iter().map(|e| *e.borrow()).collect(); |
649 | 0 | first |
650 | 0 | .into_iter() |
651 | 0 | .map(|e| *e.borrow()) |
652 | 0 | .filter(|e| !set.contains(e)) |
653 | 0 | .collect() |
654 | 0 | } |
655 | | |
656 | | /// Checks whether the given index sequence is monotonically non-decreasing. |
657 | 0 | pub fn is_sorted<T: Borrow<usize>>(sequence: impl IntoIterator<Item = T>) -> bool { |
658 | 0 | // TODO: Remove this function when `is_sorted` graduates from Rust nightly. |
659 | 0 | let mut previous = 0; |
660 | 0 | for item in sequence.into_iter() { |
661 | 0 | let current = *item.borrow(); |
662 | 0 | if current < previous { |
663 | 0 | return false; |
664 | 0 | } |
665 | 0 | previous = current; |
666 | | } |
667 | 0 | true |
668 | 0 | } |
669 | | |
670 | | /// Find indices of each element in `targets` inside `items`. If one of the |
671 | | /// elements is absent in `items`, returns an error. |
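 | | /// |
 | | /// For example: |
 | | /// ``` |
 | | /// use datafusion_common::utils::find_indices; |
 | | /// |
 | | /// assert_eq!(find_indices(&[10, 20, 30], [30, 10]).unwrap(), vec![2, 0]); |
 | | /// assert!(find_indices(&[10, 20, 30], [40]).is_err()); |
 | | /// ``` |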
672 | 0 | pub fn find_indices<T: PartialEq, S: Borrow<T>>( |
673 | 0 | items: &[T], |
674 | 0 | targets: impl IntoIterator<Item = S>, |
675 | 0 | ) -> Result<Vec<usize>> { |
676 | 0 | targets |
677 | 0 | .into_iter() |
678 | 0 | .map(|target| items.iter().position(|e| target.borrow().eq(e))) |
679 | 0 | .collect::<Option<_>>() |
680 | 0 | .ok_or_else(|| DataFusionError::Execution("Target not found".to_string())) |
681 | 0 | } |
682 | | |
683 | | /// Transposes the given vector of vectors. |
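 | | /// |
 | | /// For example: |
 | | /// ``` |
 | | /// use datafusion_common::utils::transpose; |
 | | /// |
 | | /// let data = vec![vec![1, 2, 3], vec![4, 5, 6]]; |
 | | /// assert_eq!(transpose(data), vec![vec![1, 4], vec![2, 5], vec![3, 6]]); |
 | | /// ``` |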
684 | 0 | pub fn transpose<T>(original: Vec<Vec<T>>) -> Vec<Vec<T>> { |
685 | 0 | match original.as_slice() { |
686 | 0 | [] => vec![], |
687 | 0 | [first, ..] => { |
688 | 0 | let mut result = (0..first.len()).map(|_| vec![]).collect::<Vec<_>>(); |
689 | 0 | for row in original { |
690 | 0 | for (item, transposed_row) in row.into_iter().zip(&mut result) { |
691 | 0 | transposed_row.push(item); |
692 | 0 | } |
693 | | } |
694 | 0 | result |
695 | | } |
696 | | } |
697 | 0 | } |
698 | | |
699 | | /// Computes the `skip` and `fetch` parameters of a single limit that would be |
700 | | /// equivalent to two consecutive limits with the given `skip`/`fetch` parameters. |
701 | | /// |
702 | | /// There are multiple cases to consider: |
703 | | /// |
704 | | /// # Case 0: Parent and child are disjoint (`child_fetch <= skip`). |
705 | | /// |
706 | | /// Before merging: |
707 | | /// ```text |
708 | | /// |........skip........|---fetch-->| Parent limit |
709 | | /// |...child_skip...|---child_fetch-->| Child limit |
710 | | /// ``` |
711 | | /// |
712 | | /// After merging: |
713 | | /// ```text |
714 | | /// |.........(child_skip + skip).........| |
715 | | /// ``` |
716 | | /// |
717 | | /// # Case 1: Parent is beyond child's range (`skip < child_fetch <= skip + fetch`). |
718 | | /// |
719 | | /// Before merging: |
720 | | /// ```text |
721 | | /// |...skip...|------------fetch------------>| Parent limit |
722 | | /// |...child_skip...|-------------child_fetch------------>| Child limit |
723 | | /// ``` |
724 | | /// |
725 | | /// After merging: |
726 | | /// ```text |
727 | | /// |....(child_skip + skip)....|---(child_fetch - skip)-->| |
728 | | /// ``` |
729 | | /// |
730 | | /// # Case 2: Parent is within child's range (`skip + fetch < child_fetch`). |
731 | | /// |
732 | | /// Before merging: |
733 | | /// ```text |
734 | | /// |...skip...|---fetch-->| Parent limit |
735 | | /// |...child_skip...|-------------child_fetch------------>| Child limit |
736 | | /// ``` |
737 | | /// |
738 | | /// After merging: |
739 | | /// ```text |
740 | | /// |....(child_skip + skip)....|---fetch-->| |
741 | | /// ``` |
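 | | /// |
 | | /// A small numeric sketch (the values are arbitrary): |
 | | /// ``` |
 | | /// use datafusion_common::utils::combine_limit; |
 | | /// |
 | | /// // Parent: OFFSET 5 LIMIT 10 on top of child: OFFSET 2 LIMIT 20 |
 | | /// // collapses into a single OFFSET 7 LIMIT 10 |
 | | /// assert_eq!(combine_limit(5, Some(10), 2, Some(20)), (7, Some(10))); |
 | | /// ``` |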
742 | 0 | pub fn combine_limit( |
743 | 0 | parent_skip: usize, |
744 | 0 | parent_fetch: Option<usize>, |
745 | 0 | child_skip: usize, |
746 | 0 | child_fetch: Option<usize>, |
747 | 0 | ) -> (usize, Option<usize>) { |
748 | 0 | let combined_skip = child_skip.saturating_add(parent_skip); |
749 | | |
750 | 0 | let combined_fetch = match (parent_fetch, child_fetch) { |
751 | 0 | (Some(parent_fetch), Some(child_fetch)) => { |
752 | 0 | Some(min(parent_fetch, child_fetch.saturating_sub(parent_skip))) |
753 | | } |
754 | 0 | (Some(parent_fetch), None) => Some(parent_fetch), |
755 | 0 | (None, Some(child_fetch)) => Some(child_fetch.saturating_sub(parent_skip)), |
756 | 0 | (None, None) => None, |
757 | | }; |
758 | | |
759 | 0 | (combined_skip, combined_fetch) |
760 | 0 | } |
761 | | |
762 | | #[cfg(test)] |
763 | | mod tests { |
764 | | use crate::ScalarValue::Null; |
765 | | use arrow::array::Float64Array; |
766 | | |
767 | | use super::*; |
768 | | |
769 | | #[test] |
770 | | fn test_bisect_linear_left_and_right() -> Result<()> { |
771 | | let arrays: Vec<ArrayRef> = vec![ |
772 | | Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 9., 10.])), |
773 | | Arc::new(Float64Array::from(vec![2.0, 3.0, 3.0, 4.0, 5.0])), |
774 | | Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 10., 11.0])), |
775 | | Arc::new(Float64Array::from(vec![15.0, 13.0, 8.0, 5., 0.0])), |
776 | | ]; |
777 | | let search_tuple: Vec<ScalarValue> = vec![ |
778 | | ScalarValue::Float64(Some(8.0)), |
779 | | ScalarValue::Float64(Some(3.0)), |
780 | | ScalarValue::Float64(Some(8.0)), |
781 | | ScalarValue::Float64(Some(8.0)), |
782 | | ]; |
783 | | let ords = [ |
784 | | SortOptions { |
785 | | descending: false, |
786 | | nulls_first: true, |
787 | | }, |
788 | | SortOptions { |
789 | | descending: false, |
790 | | nulls_first: true, |
791 | | }, |
792 | | SortOptions { |
793 | | descending: false, |
794 | | nulls_first: true, |
795 | | }, |
796 | | SortOptions { |
797 | | descending: true, |
798 | | nulls_first: true, |
799 | | }, |
800 | | ]; |
801 | | let res = bisect::<true>(&arrays, &search_tuple, &ords)?; |
802 | | assert_eq!(res, 2); |
803 | | let res = bisect::<false>(&arrays, &search_tuple, &ords)?; |
804 | | assert_eq!(res, 3); |
805 | | let res = linear_search::<true>(&arrays, &search_tuple, &ords)?; |
806 | | assert_eq!(res, 2); |
807 | | let res = linear_search::<false>(&arrays, &search_tuple, &ords)?; |
808 | | assert_eq!(res, 3); |
809 | | Ok(()) |
810 | | } |
811 | | |
812 | | #[test] |
813 | | fn vector_ord() { |
814 | | assert!(vec![1, 0, 0, 0, 0, 0, 0, 1] < vec![1, 0, 0, 0, 0, 0, 0, 2]); |
815 | | assert!(vec![1, 0, 0, 0, 0, 0, 1, 1] > vec![1, 0, 0, 0, 0, 0, 0, 2]); |
816 | | assert!( |
817 | | vec![ |
818 | | ScalarValue::Int32(Some(2)), |
819 | | Null, |
820 | | ScalarValue::Int32(Some(0)), |
821 | | ] < vec![ |
822 | | ScalarValue::Int32(Some(2)), |
823 | | Null, |
824 | | ScalarValue::Int32(Some(1)), |
825 | | ] |
826 | | ); |
827 | | assert!( |
828 | | vec![ |
829 | | ScalarValue::Int32(Some(2)), |
830 | | ScalarValue::Int32(None), |
831 | | ScalarValue::Int32(Some(0)), |
832 | | ] < vec![ |
833 | | ScalarValue::Int32(Some(2)), |
834 | | ScalarValue::Int32(None), |
835 | | ScalarValue::Int32(Some(1)), |
836 | | ] |
837 | | ); |
838 | | } |
839 | | |
840 | | #[test] |
841 | | fn ord_same_type() { |
842 | | assert!((ScalarValue::Int32(Some(2)) < ScalarValue::Int32(Some(3)))); |
843 | | } |
844 | | |
845 | | #[test] |
846 | | fn test_bisect_linear_left_and_right_diff_sort() -> Result<()> { |
847 | | // Descending, left |
848 | | let arrays: Vec<ArrayRef> = |
849 | | vec![Arc::new(Float64Array::from(vec![4.0, 3.0, 2.0, 1.0, 0.0]))]; |
850 | | let search_tuple: Vec<ScalarValue> = vec![ScalarValue::Float64(Some(4.0))]; |
851 | | let ords = [SortOptions { |
852 | | descending: true, |
853 | | nulls_first: true, |
854 | | }]; |
855 | | let res = bisect::<true>(&arrays, &search_tuple, &ords)?; |
856 | | assert_eq!(res, 0); |
857 | | let res = linear_search::<true>(&arrays, &search_tuple, &ords)?; |
858 | | assert_eq!(res, 0); |
859 | | |
860 | | // Descending, right |
861 | | let arrays: Vec<ArrayRef> = |
862 | | vec![Arc::new(Float64Array::from(vec![4.0, 3.0, 2.0, 1.0, 0.0]))]; |
863 | | let search_tuple: Vec<ScalarValue> = vec![ScalarValue::Float64(Some(4.0))]; |
864 | | let ords = [SortOptions { |
865 | | descending: true, |
866 | | nulls_first: true, |
867 | | }]; |
868 | | let res = bisect::<false>(&arrays, &search_tuple, &ords)?; |
869 | | assert_eq!(res, 1); |
870 | | let res = linear_search::<false>(&arrays, &search_tuple, &ords)?; |
871 | | assert_eq!(res, 1); |
872 | | |
873 | | // Ascending, left |
874 | | let arrays: Vec<ArrayRef> = |
875 | | vec![Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 9., 10.]))]; |
876 | | let search_tuple: Vec<ScalarValue> = vec![ScalarValue::Float64(Some(7.0))]; |
877 | | let ords = [SortOptions { |
878 | | descending: false, |
879 | | nulls_first: true, |
880 | | }]; |
881 | | let res = bisect::<true>(&arrays, &search_tuple, &ords)?; |
882 | | assert_eq!(res, 1); |
883 | | let res = linear_search::<true>(&arrays, &search_tuple, &ords)?; |
884 | | assert_eq!(res, 1); |
885 | | |
886 | | // Ascending, right |
887 | | let arrays: Vec<ArrayRef> = |
888 | | vec![Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 9., 10.]))]; |
889 | | let search_tuple: Vec<ScalarValue> = vec![ScalarValue::Float64(Some(7.0))]; |
890 | | let ords = [SortOptions { |
891 | | descending: false, |
892 | | nulls_first: true, |
893 | | }]; |
894 | | let res = bisect::<false>(&arrays, &search_tuple, &ords)?; |
895 | | assert_eq!(res, 2); |
896 | | let res = linear_search::<false>(&arrays, &search_tuple, &ords)?; |
897 | | assert_eq!(res, 2); |
898 | | |
899 | | let arrays: Vec<ArrayRef> = vec![ |
900 | | Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 8.0, 9., 10.])), |
901 | | Arc::new(Float64Array::from(vec![10.0, 9.0, 8.0, 7.5, 7., 6.])), |
902 | | ]; |
903 | | let search_tuple: Vec<ScalarValue> = vec![ |
904 | | ScalarValue::Float64(Some(8.0)), |
905 | | ScalarValue::Float64(Some(8.0)), |
906 | | ]; |
907 | | let ords = [ |
908 | | SortOptions { |
909 | | descending: false, |
910 | | nulls_first: true, |
911 | | }, |
912 | | SortOptions { |
913 | | descending: true, |
914 | | nulls_first: true, |
915 | | }, |
916 | | ]; |
917 | | let res = bisect::<false>(&arrays, &search_tuple, &ords)?; |
918 | | assert_eq!(res, 3); |
919 | | let res = linear_search::<false>(&arrays, &search_tuple, &ords)?; |
920 | | assert_eq!(res, 3); |
921 | | |
922 | | let res = bisect::<true>(&arrays, &search_tuple, &ords)?; |
923 | | assert_eq!(res, 2); |
924 | | let res = linear_search::<true>(&arrays, &search_tuple, &ords)?; |
925 | | assert_eq!(res, 2); |
926 | | Ok(()) |
927 | | } |
928 | | |
929 | | #[test] |
930 | | fn test_evaluate_partition_ranges() -> Result<()> { |
931 | | let arrays: Vec<ArrayRef> = vec![ |
932 | | Arc::new(Float64Array::from(vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0])), |
933 | | Arc::new(Float64Array::from(vec![4.0, 4.0, 3.0, 2.0, 1.0, 1.0])), |
934 | | ]; |
935 | | let n_row = arrays[0].len(); |
936 | | let options: Vec<SortOptions> = vec![ |
937 | | SortOptions { |
938 | | descending: false, |
939 | | nulls_first: false, |
940 | | }, |
941 | | SortOptions { |
942 | | descending: true, |
943 | | nulls_first: false, |
944 | | }, |
945 | | ]; |
946 | | let sort_columns = arrays |
947 | | .into_iter() |
948 | | .zip(options) |
949 | | .map(|(values, options)| SortColumn { |
950 | | values, |
951 | | options: Some(options), |
952 | | }) |
953 | | .collect::<Vec<_>>(); |
954 | | let ranges = evaluate_partition_ranges(n_row, &sort_columns)?; |
955 | | assert_eq!(ranges.len(), 4); |
956 | | assert_eq!(ranges[0], Range { start: 0, end: 2 }); |
957 | | assert_eq!(ranges[1], Range { start: 2, end: 3 }); |
958 | | assert_eq!(ranges[2], Range { start: 3, end: 4 }); |
959 | | assert_eq!(ranges[3], Range { start: 4, end: 6 }); |
960 | | Ok(()) |
961 | | } |
962 | | |
963 | | #[test] |
964 | | fn test_quote_identifier() -> Result<()> { |
965 | | let cases = vec![ |
966 | | ("foo", r#"foo"#), |
967 | | ("_foo", r#"_foo"#), |
968 | | ("foo_bar", r#"foo_bar"#), |
969 | | ("foo-bar", r#""foo-bar""#), |
970 | | // name itself has a period, needs to be quoted |
971 | | ("foo.bar", r#""foo.bar""#), |
972 | | ("Foo", r#""Foo""#), |
973 | | ("Foo.Bar", r#""Foo.Bar""#), |
974 | | // name starting with a number needs to be quoted |
975 | | ("test1", r#"test1"#), |
976 | | ("1test", r#""1test""#), |
977 | | ]; |
978 | | |
979 | | for (identifier, quoted_identifier) in cases { |
980 | | println!("input: \n{identifier}\nquoted_identifier:\n{quoted_identifier}"); |
981 | | |
982 | | assert_eq!(quote_identifier(identifier), quoted_identifier); |
983 | | |
984 | | // When parsing the quoted identifier, it should be a |
985 | | // single identifier without normalization, and not in multiple parts |
986 | | let quote_style = if quoted_identifier.starts_with('"') { |
987 | | Some('"') |
988 | | } else { |
989 | | None |
990 | | }; |
991 | | |
992 | | let expected_parsed = vec![Ident { |
993 | | value: identifier.to_string(), |
994 | | quote_style, |
995 | | }]; |
996 | | |
997 | | assert_eq!( |
998 | | parse_identifiers(quoted_identifier).unwrap(), |
999 | | expected_parsed |
1000 | | ); |
1001 | | } |
1002 | | |
1003 | | Ok(()) |
1004 | | } |
1005 | | |
1006 | | #[test] |
1007 | | fn test_take_arrays() -> Result<()> { |
1008 | | let arrays: Vec<ArrayRef> = vec![ |
1009 | | Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 9., 10.])), |
1010 | | Arc::new(Float64Array::from(vec![2.0, 3.0, 3.0, 4.0, 5.0])), |
1011 | | Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 10., 11.0])), |
1012 | | Arc::new(Float64Array::from(vec![15.0, 13.0, 8.0, 5., 0.0])), |
1013 | | ]; |
1014 | | |
1015 | | let row_indices_vec: Vec<Vec<u32>> = vec![ |
1016 | | // Get rows 0 and 1 |
1017 | | vec![0, 1], |
1018 | | // Get rows 0 and 2 |
1019 | | vec![0, 2], |
1020 | | // Get rows 1 and 3 |
1021 | | vec![1, 3], |
1022 | | // Get rows 2 and 4 |
1023 | | vec![2, 4], |
1024 | | ]; |
1025 | | for row_indices in row_indices_vec { |
1026 | | let indices: PrimitiveArray<UInt32Type> = |
1027 | | PrimitiveArray::from_iter_values(row_indices.iter().cloned()); |
1028 | | let chunk = take_arrays(&arrays, &indices)?; |
1029 | | for (arr_orig, arr_chunk) in arrays.iter().zip(&chunk) { |
1030 | | for (idx, orig_idx) in row_indices.iter().enumerate() { |
1031 | | let res1 = ScalarValue::try_from_array(arr_orig, *orig_idx as usize)?; |
1032 | | let res2 = ScalarValue::try_from_array(arr_chunk, idx)?; |
1033 | | assert_eq!(res1, res2); |
1034 | | } |
1035 | | } |
1036 | | } |
1037 | | Ok(()) |
1038 | | } |
1039 | | |
1040 | | #[test] |
1041 | | fn test_get_at_indices() -> Result<()> { |
1042 | | let in_vec = vec![1, 2, 3, 4, 5, 6, 7]; |
1043 | | assert_eq!(get_at_indices(&in_vec, [0, 2])?, vec![1, 3]); |
1044 | | assert_eq!(get_at_indices(&in_vec, [4, 2])?, vec![5, 3]); |
1045 | | // 7 is outside the range |
1046 | | assert!(get_at_indices(&in_vec, [7]).is_err()); |
1047 | | Ok(()) |
1048 | | } |
1049 | | |
1050 | | #[test] |
1051 | | fn test_longest_consecutive_prefix() { |
1052 | | assert_eq!(longest_consecutive_prefix([0, 3, 4]), 1); |
1053 | | assert_eq!(longest_consecutive_prefix([0, 1, 3, 4]), 2); |
1054 | | assert_eq!(longest_consecutive_prefix([0, 1, 2, 3, 4]), 5); |
1055 | | assert_eq!(longest_consecutive_prefix([1, 2, 3, 4]), 0); |
1056 | | } |
1057 | | |
1058 | | #[test] |
1059 | | fn test_merge_and_order_indices() { |
1060 | | assert_eq!( |
1061 | | merge_and_order_indices([0, 3, 4], [1, 3, 5]), |
1062 | | vec![0, 1, 3, 4, 5] |
1063 | | ); |
1064 | | // Result should be ordered, even if inputs are not |
1065 | | assert_eq!( |
1066 | | merge_and_order_indices([3, 0, 4], [5, 1, 3]), |
1067 | | vec![0, 1, 3, 4, 5] |
1068 | | ); |
1069 | | } |
1070 | | |
1071 | | #[test] |
1072 | | fn test_set_difference() { |
1073 | | assert_eq!(set_difference([0, 3, 4], [1, 2]), vec![0, 3, 4]); |
1074 | | assert_eq!(set_difference([0, 3, 4], [1, 2, 4]), vec![0, 3]); |
1075 | | // The return value should preserve the ordering of the first input |
1076 | | assert_eq!(set_difference([3, 4, 0], [1, 2, 4]), vec![3, 0]); |
1077 | | assert_eq!(set_difference([0, 3, 4], [4, 1, 2]), vec![0, 3]); |
1078 | | assert_eq!(set_difference([3, 4, 0], [4, 1, 2]), vec![3, 0]); |
1079 | | } |
1080 | | |
1081 | | #[test] |
1082 | | fn test_is_sorted() { |
1083 | | assert!(is_sorted::<usize>([])); |
1084 | | assert!(is_sorted([0])); |
1085 | | assert!(is_sorted([0, 3, 4])); |
1086 | | assert!(is_sorted([0, 1, 2])); |
1087 | | assert!(is_sorted([0, 1, 4])); |
1088 | | assert!(is_sorted([0usize; 0])); |
1089 | | assert!(is_sorted([1, 2])); |
1090 | | assert!(!is_sorted([3, 2])); |
1091 | | } |
1092 | | |
1093 | | #[test] |
1094 | | fn test_find_indices() -> Result<()> { |
1095 | | assert_eq!(find_indices(&[0, 3, 4], [0, 3, 4])?, vec![0, 1, 2]); |
1096 | | assert_eq!(find_indices(&[0, 3, 4], [0, 4, 3])?, vec![0, 2, 1]); |
1097 | | assert_eq!(find_indices(&[3, 0, 4], [0, 3])?, vec![1, 0]); |
1098 | | assert!(find_indices(&[0, 3], [0, 3, 4]).is_err()); |
1099 | | assert!(find_indices(&[0, 3, 4], [0, 2]).is_err()); |
1100 | | Ok(()) |
1101 | | } |
1102 | | |
1103 | | #[test] |
1104 | | fn test_transpose() -> Result<()> { |
1105 | | let in_data = vec![vec![1, 2, 3], vec![4, 5, 6]]; |
1106 | | let transposed = transpose(in_data); |
1107 | | let expected = vec![vec![1, 4], vec![2, 5], vec![3, 6]]; |
1108 | | assert_eq!(expected, transposed); |
1109 | | Ok(()) |
1110 | | } |
1111 | | } |