/Users/andrewlamb/Software/datafusion/datafusion/physical-expr-common/src/binary_view_map.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! [`ArrowBytesViewMap`] and [`ArrowBytesViewSet`] for storing maps/sets of values from |
19 | | //! `StringViewArray`/`BinaryViewArray`. |
20 | | //! Much of the code is from `binary_map.rs`, but with a simpler implementation because we directly use the |
21 | | //! [`GenericByteViewBuilder`]. |
22 | | use ahash::RandomState; |
23 | | use arrow::array::cast::AsArray; |
24 | | use arrow::array::{Array, ArrayBuilder, ArrayRef, GenericByteViewBuilder}; |
25 | | use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType}; |
26 | | use datafusion_common::hash_utils::create_hashes; |
27 | | use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt}; |
28 | | use std::fmt::Debug; |
29 | | use std::sync::Arc; |
30 | | |
31 | | use crate::binary_map::OutputType; |
32 | | |
33 | | /// HashSet optimized for storing string or binary values that can produce |
34 | | /// the final set as a `GenericBinaryViewArray` with minimal copies. |
35 | | #[derive(Debug)] |
36 | | pub struct ArrowBytesViewSet(ArrowBytesViewMap<()>); |
37 | | |
38 | | impl ArrowBytesViewSet { |
39 | 0 | pub fn new(output_type: OutputType) -> Self { |
40 | 0 | Self(ArrowBytesViewMap::new(output_type)) |
41 | 0 | } |
42 | | |
43 | | /// Inserts each value from `values` into the set |
44 | 0 | pub fn insert(&mut self, values: &ArrayRef) { |
45 | 0 | fn make_payload_fn(_value: Option<&[u8]>) {} |
46 | 0 | fn observe_payload_fn(_payload: ()) {} |
47 | 0 | self.0 |
48 | 0 | .insert_if_new(values, make_payload_fn, observe_payload_fn); |
49 | 0 | } |
50 | | |
51 | | /// Return the contents of this map and replace it with a new empty map with |
52 | | /// the same output type |
53 | 0 | pub fn take(&mut self) -> Self { |
54 | 0 | let mut new_self = Self::new(self.0.output_type); |
55 | 0 | std::mem::swap(self, &mut new_self); |
56 | 0 | new_self |
57 | 0 | } |
58 | | |
59 | | /// Converts this set into a `StringViewArray` or `BinaryViewArray` |
60 | | /// containing each distinct value that was interned. |
61 | | /// This is done without copying the values. |
62 | 0 | pub fn into_state(self) -> ArrayRef { |
63 | 0 | self.0.into_state() |
64 | 0 | } |
65 | | |
66 | | /// Returns the total number of distinct values (including nulls) seen so far |
67 | 0 | pub fn len(&self) -> usize { |
68 | 0 | self.0.len() |
69 | 0 | } |
70 | | |
71 | 0 | pub fn is_empty(&self) -> bool { |
72 | 0 | self.0.is_empty() |
73 | 0 | } |
74 | | |
75 | | /// returns the total number of distinct values (not including nulls) seen so far |
76 | 0 | pub fn non_null_len(&self) -> usize { |
77 | 0 | self.0.non_null_len() |
78 | 0 | } |
79 | | |
80 | | /// Return the total size, in bytes, of memory used to store the data in |
81 | | /// this set, not including `self` |
82 | 0 | pub fn size(&self) -> usize { |
83 | 0 | self.0.size() |
84 | 0 | } |
85 | | } |
86 | | |
87 | | /// Optimized map for storing Arrow "byte view" types (`StringView`, `BinaryView`) |
88 | | /// values that can produce the set of keys on |
89 | | /// output as `GenericBinaryViewArray` without copies. |
90 | | /// |
91 | | /// Equivalent to `HashMap<String, V>` but with better performance for arrow |
92 | | /// data. |
93 | | /// |
94 | | /// # Generic Arguments |
95 | | /// |
96 | | /// * `V`: payload type |
97 | | /// |
98 | | /// # Description |
99 | | /// |
100 | | /// This is a specialized HashMap with the following properties: |
101 | | /// |
102 | | /// 1. Optimized for storing and emitting Arrow byte types (e.g. |
103 | | /// `StringViewArray` / `BinaryViewArray`) very efficiently by minimizing copying of |
104 | | /// the string values themselves, both when inserting and when emitting the |
105 | | /// final array. |
106 | | /// |
107 | | /// 2. Retains the insertion order of entries in the final array. The values are |
108 | | /// in the same order as they were inserted. |
109 | | /// |
110 | | /// Note this structure can be used as a `HashSet` by specifying the value type |
111 | | /// as `()`, as is done by [`ArrowBytesViewSet`]. |
112 | | /// |
113 | | /// This map is used by the special `COUNT DISTINCT` aggregate function to |
114 | | /// store the distinct values, and by the `GROUP BY` operator to store |
115 | | /// group values when they are a single string array. |
116 | | |
117 | | pub struct ArrowBytesViewMap<V> |
118 | | where |
119 | | V: Debug + PartialEq + Eq + Clone + Copy + Default, |
120 | | { |
121 | | /// Should the output be StringView or BinaryView? |
122 | | output_type: OutputType, |
123 | | /// Underlying hash set for each distinct value |
124 | | map: hashbrown::raw::RawTable<Entry<V>>, |
125 | | /// Total size of the map in bytes |
126 | | map_size: usize, |
127 | | |
128 | | /// Builder for output array |
129 | | builder: GenericByteViewBuilder<BinaryViewType>, |
130 | | /// random state used to generate hashes |
131 | | random_state: RandomState, |
132 | | /// buffer that stores hash values (reused across batches to save allocations) |
133 | | hashes_buffer: Vec<u64>, |
134 | | /// `(payload, null_index)` for the 'null' value, if any |
135 | | /// NOTE null_index is the logical index in the final array, not the index |
136 | | /// in the buffer |
137 | | null: Option<(V, usize)>, |
138 | | } |
139 | | |
140 | | /// The size, in number of entries, of the initial hash table |
141 | | const INITIAL_MAP_CAPACITY: usize = 512; |
142 | | |
143 | | impl<V> ArrowBytesViewMap<V> |
144 | | where |
145 | | V: Debug + PartialEq + Eq + Clone + Copy + Default, |
146 | | { |
147 | 0 | pub fn new(output_type: OutputType) -> Self { |
148 | 0 | Self { |
149 | 0 | output_type, |
150 | 0 | map: hashbrown::raw::RawTable::with_capacity(INITIAL_MAP_CAPACITY), |
151 | 0 | map_size: 0, |
152 | 0 | builder: GenericByteViewBuilder::new(), |
153 | 0 | random_state: RandomState::new(), |
154 | 0 | hashes_buffer: vec![], |
155 | 0 | null: None, |
156 | 0 | } |
157 | 0 | } |
158 | | |
159 | | /// Return the contents of this map and replace it with a new empty map with |
160 | | /// the same output type |
161 | 0 | pub fn take(&mut self) -> Self { |
162 | 0 | let mut new_self = Self::new(self.output_type); |
163 | 0 | std::mem::swap(self, &mut new_self); |
164 | 0 | new_self |
165 | 0 | } |
166 | | |
167 | | /// Inserts each value from `values` into the map, invoking `payload_fn` for |
168 | | /// each value if *not* already present, deferring the allocation of the |
169 | | /// payload until it is needed. |
170 | | /// |
171 | | /// Note that this is different than a normal map that would replace the |
172 | | /// existing entry |
173 | | /// |
174 | | /// # Arguments: |
175 | | /// |
176 | | /// `values`: array whose values are inserted |
177 | | /// |
178 | | /// `make_payload_fn`: invoked for each value that is not already present |
179 | | /// to create the payload, in order of the values in `values` |
180 | | /// |
181 | | /// `observe_payload_fn`: invoked once for each value in `values`, whether it was |
182 | | /// newly inserted or already present in the map, with the corresponding payload value. |
183 | | /// |
184 | | /// # Returns |
185 | | /// |
186 | | /// Nothing is returned directly; the payload for each value, either the |
187 | | /// existing value or the newly inserted value, is passed to `observe_payload_fn` |
188 | | /// |
189 | | /// # Safety: |
190 | | /// |
191 | | /// Note that `make_payload_fn` is invoked with `None` the first time a `NULL` value |
192 | | /// is seen, and `observe_payload_fn` is invoked with the `NULL` payload for every `NULL` value. |
193 | 0 | pub fn insert_if_new<MP, OP>( |
194 | 0 | &mut self, |
195 | 0 | values: &ArrayRef, |
196 | 0 | make_payload_fn: MP, |
197 | 0 | observe_payload_fn: OP, |
198 | 0 | ) where |
199 | 0 | MP: FnMut(Option<&[u8]>) -> V, |
200 | 0 | OP: FnMut(V), |
201 | 0 | { |
202 | 0 | // Sanity check array type |
203 | 0 | match self.output_type { |
204 | | OutputType::BinaryView => { |
205 | 0 | assert!(matches!(values.data_type(), DataType::BinaryView)); |
206 | 0 | self.insert_if_new_inner::<MP, OP, BinaryViewType>( |
207 | 0 | values, |
208 | 0 | make_payload_fn, |
209 | 0 | observe_payload_fn, |
210 | 0 | ) |
211 | | } |
212 | | OutputType::Utf8View => { |
213 | 0 | assert!(matches!(values.data_type(), DataType::Utf8View)); |
214 | 0 | self.insert_if_new_inner::<MP, OP, StringViewType>( |
215 | 0 | values, |
216 | 0 | make_payload_fn, |
217 | 0 | observe_payload_fn, |
218 | 0 | ) |
219 | | } |
220 | 0 | _ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"), |
221 | | }; |
222 | 0 | } |
223 | | |
224 | | /// Generic version of [`Self::insert_if_new`] that handles `ByteViewType` |
225 | | /// (both StringView and BinaryView) |
226 | | /// |
227 | | /// Note this is the only function that is generic on [`ByteViewType`], which |
228 | | /// avoids having to template the entire structure, making the code |
229 | | /// simpler to understand and reducing code bloat due to duplication. |
230 | | /// |
231 | | /// See comments on `insert_if_new` for more details |
232 | 0 | fn insert_if_new_inner<MP, OP, B>( |
233 | 0 | &mut self, |
234 | 0 | values: &ArrayRef, |
235 | 0 | mut make_payload_fn: MP, |
236 | 0 | mut observe_payload_fn: OP, |
237 | 0 | ) where |
238 | 0 | MP: FnMut(Option<&[u8]>) -> V, |
239 | 0 | OP: FnMut(V), |
240 | 0 | B: ByteViewType, |
241 | 0 | { |
242 | 0 | // step 1: compute hashes |
243 | 0 | let batch_hashes = &mut self.hashes_buffer; |
244 | 0 | batch_hashes.clear(); |
245 | 0 | batch_hashes.resize(values.len(), 0); |
246 | 0 | create_hashes(&[values.clone()], &self.random_state, batch_hashes) |
247 | 0 | // hash is supported for all types and create_hashes only |
248 | 0 | // returns errors for unsupported types |
249 | 0 | .unwrap(); |
250 | 0 |
|
251 | 0 | // step 2: insert each value into the set, if not already present |
252 | 0 | let values = values.as_byte_view::<B>(); |
253 | 0 |
|
254 | 0 | // Ensure lengths are equivalent |
255 | 0 | assert_eq!(values.len(), batch_hashes.len()); |
256 | | |
257 | 0 | for (value, &hash) in values.iter().zip(batch_hashes.iter()) { |
258 | | // handle null value |
259 | 0 | let Some(value) = value else { |
260 | 0 | let payload = if let Some(&(payload, _offset)) = self.null.as_ref() { |
261 | 0 | payload |
262 | | } else { |
263 | 0 | let payload = make_payload_fn(None); |
264 | 0 | let null_index = self.builder.len(); |
265 | 0 | self.builder.append_null(); |
266 | 0 | self.null = Some((payload, null_index)); |
267 | 0 | payload |
268 | | }; |
269 | 0 | observe_payload_fn(payload); |
270 | 0 | continue; |
271 | | }; |
272 | | |
273 | | // get the value as bytes |
274 | 0 | let value: &[u8] = value.as_ref(); |
275 | 0 |
|
276 | 0 | let entry = self.map.get_mut(hash, |header| { |
277 | 0 | let v = self.builder.get_value(header.view_idx); |
278 | 0 |
|
279 | 0 | if v.len() != value.len() { |
280 | 0 | return false; |
281 | 0 | } |
282 | 0 |
|
283 | 0 | v == value |
284 | 0 | }); |
285 | | |
286 | 0 | let payload = if let Some(entry) = entry { |
287 | 0 | entry.payload |
288 | | } else { |
289 | | // no existing value, make a new one. |
290 | 0 | let payload = make_payload_fn(Some(value)); |
291 | 0 |
|
292 | 0 | let inner_view_idx = self.builder.len(); |
293 | 0 | let new_header = Entry { |
294 | 0 | view_idx: inner_view_idx, |
295 | 0 | hash, |
296 | 0 | payload, |
297 | 0 | }; |
298 | 0 |
|
299 | 0 | self.builder.append_value(value); |
300 | 0 |
|
301 | 0 | self.map |
302 | 0 | .insert_accounted(new_header, |h| h.hash, &mut self.map_size); |
303 | 0 | payload |
304 | | }; |
305 | 0 | observe_payload_fn(payload); |
306 | | } |
307 | 0 | } |
308 | | |
309 | | /// Converts this set into a `StringViewArray`, or `BinaryViewArray`, |
310 | | /// containing each distinct value |
311 | | /// that was inserted. This is done without copying the values. |
312 | | /// |
313 | | /// The values are guaranteed to be returned in the same order in which |
314 | | /// they were first seen. |
315 | 0 | pub fn into_state(self) -> ArrayRef { |
316 | 0 | let mut builder = self.builder; |
317 | 0 | match self.output_type { |
318 | | OutputType::BinaryView => { |
319 | 0 | let array = builder.finish(); |
320 | 0 |
|
321 | 0 | Arc::new(array) |
322 | | } |
323 | | OutputType::Utf8View => { |
324 | | // SAFETY: |
325 | | // we asserted the input arrays were all the correct type and |
326 | | // thus since all the values that went in were valid (e.g. utf8) |
327 | | // so are all the values that come out |
328 | 0 | let array = builder.finish(); |
329 | 0 | let array = unsafe { array.to_string_view_unchecked() }; |
330 | 0 | Arc::new(array) |
331 | | } |
332 | | _ => { |
333 | 0 | unreachable!("Utf8/Binary should use `ArrowBytesMap`") |
334 | | } |
335 | | } |
336 | 0 | } |
337 | | |
338 | | /// Total number of entries (including null, if present) |
339 | 0 | pub fn len(&self) -> usize { |
340 | 0 | self.non_null_len() + self.null.map(|_| 1).unwrap_or(0) |
341 | 0 | } |
342 | | |
343 | | /// Is the set empty? |
344 | 0 | pub fn is_empty(&self) -> bool { |
345 | 0 | self.map.is_empty() && self.null.is_none() |
346 | 0 | } |
347 | | |
348 | | /// Number of non null entries |
349 | 0 | pub fn non_null_len(&self) -> usize { |
350 | 0 | self.map.len() |
351 | 0 | } |
352 | | |
353 | | /// Return the total size, in bytes, of memory used to store the data in |
354 | | /// this set, not including `self` |
355 | 0 | pub fn size(&self) -> usize { |
356 | 0 | self.map_size |
357 | 0 | + self.builder.allocated_size() |
358 | 0 | + self.hashes_buffer.allocated_size() |
359 | 0 | } |
360 | | } |
361 | | |
362 | | impl<V> Debug for ArrowBytesViewMap<V> |
363 | | where |
364 | | V: Debug + PartialEq + Eq + Clone + Copy + Default, |
365 | | { |
366 | 0 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
367 | 0 | f.debug_struct("ArrowBytesMap") |
368 | 0 | .field("map", &"<map>") |
369 | 0 | .field("map_size", &self.map_size) |
370 | 0 | .field("view_builder", &self.builder) |
371 | 0 | .field("random_state", &self.random_state) |
372 | 0 | .field("hashes_buffer", &self.hashes_buffer) |
373 | 0 | .finish() |
374 | 0 | } |
375 | | } |
376 | | |
377 | | /// Entry in the hash table -- see [`ArrowBytesViewMap`] for more details |
378 | | #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] |
379 | | struct Entry<V> |
380 | | where |
381 | | V: Debug + PartialEq + Eq + Clone + Copy + Default, |
382 | | { |
383 | | /// The idx into the views array |
384 | | view_idx: usize, |
385 | | |
386 | | hash: u64, |
387 | | |
388 | | /// value stored by the entry |
389 | | payload: V, |
390 | | } |
391 | | |
392 | | #[cfg(test)] |
393 | | mod tests { |
394 | | use arrow::array::{BinaryViewArray, GenericByteViewArray, StringViewArray}; |
395 | | use hashbrown::HashMap; |
396 | | |
397 | | use super::*; |
398 | | |
399 | | // asserts that the set contains the expected strings, in the same order |
400 | | fn assert_set(set: ArrowBytesViewSet, expected: &[Option<&str>]) { |
401 | | let strings = set.into_state(); |
402 | | let strings = strings.as_string_view(); |
403 | | let state = strings.into_iter().collect::<Vec<_>>(); |
404 | | assert_eq!(state, expected); |
405 | | } |
406 | | |
407 | | #[test] |
408 | | fn string_view_set_empty() { |
409 | | let mut set = ArrowBytesViewSet::new(OutputType::Utf8View); |
410 | | let array: ArrayRef = Arc::new(StringViewArray::new_null(0)); |
411 | | set.insert(&array); |
412 | | assert_eq!(set.len(), 0); |
413 | | assert_eq!(set.non_null_len(), 0); |
414 | | assert_set(set, &[]); |
415 | | } |
416 | | |
417 | | #[test] |
418 | | fn string_view_set_one_null() { |
419 | | let mut set = ArrowBytesViewSet::new(OutputType::Utf8View); |
420 | | let array: ArrayRef = Arc::new(StringViewArray::new_null(1)); |
421 | | set.insert(&array); |
422 | | assert_eq!(set.len(), 1); |
423 | | assert_eq!(set.non_null_len(), 0); |
424 | | assert_set(set, &[None]); |
425 | | } |
426 | | |
427 | | #[test] |
428 | | fn string_view_set_many_null() { |
429 | | let mut set = ArrowBytesViewSet::new(OutputType::Utf8View); |
430 | | let array: ArrayRef = Arc::new(StringViewArray::new_null(11)); |
431 | | set.insert(&array); |
432 | | assert_eq!(set.len(), 1); |
433 | | assert_eq!(set.non_null_len(), 0); |
434 | | assert_set(set, &[None]); |
435 | | } |
436 | | |
437 | | #[test] |
438 | | fn test_string_view_set_basic() { |
439 | | // basic test for mixed small and large string values |
440 | | let values = GenericByteViewArray::from(vec![ |
441 | | Some("a"), |
442 | | Some("b"), |
443 | | Some("CXCCCCCCCCAABB"), // 14 bytes |
444 | | Some(""), |
445 | | Some("cbcxx"), // 5 bytes |
446 | | None, |
447 | | Some("AAAAAAAA"), // 8 bytes |
448 | | Some("BBBBBQBBBAAA"), // 12 bytes |
449 | | Some("a"), |
450 | | Some("cbcxx"), |
451 | | Some("b"), |
452 | | Some("cbcxx"), |
453 | | Some(""), |
454 | | None, |
455 | | Some("BBBBBQBBBAAA"), |
456 | | Some("BBBBBQBBBAAA"), |
457 | | Some("AAAAAAAA"), |
458 | | Some("CXCCCCCCCCAABB"), |
459 | | ]); |
460 | | |
461 | | let mut set = ArrowBytesViewSet::new(OutputType::Utf8View); |
462 | | let array: ArrayRef = Arc::new(values); |
463 | | set.insert(&array); |
464 | | // values must appear in the order they were inserted |
465 | | assert_set( |
466 | | set, |
467 | | &[ |
468 | | Some("a"), |
469 | | Some("b"), |
470 | | Some("CXCCCCCCCCAABB"), |
471 | | Some(""), |
472 | | Some("cbcxx"), |
473 | | None, |
474 | | Some("AAAAAAAA"), |
475 | | Some("BBBBBQBBBAAA"), |
476 | | ], |
477 | | ); |
478 | | } |
479 | | |
480 | | #[test] |
481 | | fn test_string_set_non_utf8() { |
482 | | // basic test for mixed small and large string values |
483 | | let values = GenericByteViewArray::from(vec![ |
484 | | Some("a"), |
485 | | Some("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥"), |
486 | | Some("🔥"), |
487 | | Some("✨✨✨"), |
488 | | Some("foobarbaz"), |
489 | | Some("🔥"), |
490 | | Some("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥"), |
491 | | ]); |
492 | | |
493 | | let mut set = ArrowBytesViewSet::new(OutputType::Utf8View); |
494 | | let array: ArrayRef = Arc::new(values); |
495 | | set.insert(&array); |
496 | | // strings must appear in the order they were inserted |
497 | | assert_set( |
498 | | set, |
499 | | &[ |
500 | | Some("a"), |
501 | | Some("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥"), |
502 | | Some("🔥"), |
503 | | Some("✨✨✨"), |
504 | | Some("foobarbaz"), |
505 | | ], |
506 | | ); |
507 | | } |
508 | | |
509 | | // Test use of binary output type |
510 | | #[test] |
511 | | fn test_binary_set() { |
512 | | let v: Vec<Option<&[u8]>> = vec![ |
513 | | Some(b"a"), |
514 | | Some(b"CXCCCCCCCCCCCCC"), |
515 | | None, |
516 | | Some(b"CXCCCCCCCCCCCCC"), |
517 | | ]; |
518 | | let values: ArrayRef = Arc::new(BinaryViewArray::from(v)); |
519 | | |
520 | | let expected: Vec<Option<&[u8]>> = |
521 | | vec![Some(b"a"), Some(b"CXCCCCCCCCCCCCC"), None]; |
522 | | let expected: ArrayRef = Arc::new(GenericByteViewArray::from(expected)); |
523 | | |
524 | | let mut set = ArrowBytesViewSet::new(OutputType::BinaryView); |
525 | | set.insert(&values); |
526 | | assert_eq!(&set.into_state(), &expected); |
527 | | } |
528 | | |
529 | | // re-inserting existing strings must not change reported memory; inserting new strings must increase it |
530 | | #[test] |
531 | | fn test_string_set_memory_usage() { |
532 | | let strings1 = StringViewArray::from(vec![ |
533 | | Some("a"), |
534 | | Some("b"), |
535 | | Some("CXCCCCCCCCCCC"), // 13 bytes |
536 | | Some("AAAAAAAA"), // 8 bytes |
537 | | Some("BBBBBQBBB"), // 9 bytes |
538 | | ]); |
539 | | let total_strings1_len = strings1 |
540 | | .iter() |
541 | | .map(|s| s.map(|s| s.len()).unwrap_or(0)) |
542 | | .sum::<usize>(); |
543 | | let values1: ArrayRef = Arc::new(StringViewArray::from(strings1)); |
544 | | |
545 | | // Much larger strings in strings2 |
546 | | let strings2 = StringViewArray::from(vec![ |
547 | | "FOO".repeat(1000), |
548 | | "BAR larger than 12 bytes.".repeat(100_000), |
549 | | "more unique.".repeat(1000), |
550 | | "more unique2.".repeat(1000), |
551 | | "FOO".repeat(3000), |
552 | | ]); |
553 | | let total_strings2_len = strings2 |
554 | | .iter() |
555 | | .map(|s| s.map(|s| s.len()).unwrap_or(0)) |
556 | | .sum::<usize>(); |
557 | | let values2: ArrayRef = Arc::new(StringViewArray::from(strings2)); |
558 | | |
559 | | let mut set = ArrowBytesViewSet::new(OutputType::Utf8View); |
560 | | let size_empty = set.size(); |
561 | | |
562 | | set.insert(&values1); |
563 | | let size_after_values1 = set.size(); |
564 | | assert!(size_empty < size_after_values1); |
565 | | assert!( |
566 | | size_after_values1 > total_strings1_len, |
567 | | "expect {size_after_values1} to be more than {total_strings1_len}" |
568 | | ); |
569 | | assert!(size_after_values1 < total_strings1_len + total_strings2_len); |
570 | | |
571 | | // inserting the same strings should not affect the size |
572 | | set.insert(&values1); |
573 | | assert_eq!(set.size(), size_after_values1); |
574 | | assert_eq!(set.len(), 5); |
575 | | |
576 | | // inserting the large strings should increase the reported size |
577 | | set.insert(&values2); |
578 | | let size_after_values2 = set.size(); |
579 | | assert!(size_after_values2 > size_after_values1); |
580 | | |
581 | | assert_eq!(set.len(), 10); |
582 | | } |
583 | | |
584 | | #[derive(Debug, PartialEq, Eq, Default, Clone, Copy)] |
585 | | struct TestPayload { |
586 | | // store the string value to check against input |
587 | | index: usize, // store the index of the string (each new string gets the next sequential input) |
588 | | } |
589 | | |
590 | | /// Wraps an [`ArrowBytesViewMap`], validating its invariants |
591 | | struct TestMap { |
592 | | map: ArrowBytesViewMap<TestPayload>, |
593 | | // stores distinct strings seen, in order |
594 | | strings: Vec<Option<String>>, |
595 | | // map strings to index in strings |
596 | | indexes: HashMap<Option<String>, usize>, |
597 | | } |
598 | | |
599 | | impl TestMap { |
600 | | /// creates a map with TestPayloads for the given strings and then |
601 | | /// validates the payloads |
602 | | fn new() -> Self { |
603 | | Self { |
604 | | map: ArrowBytesViewMap::new(OutputType::Utf8View), |
605 | | strings: vec![], |
606 | | indexes: HashMap::new(), |
607 | | } |
608 | | } |
609 | | |
610 | | /// Inserts strings into the map |
611 | | fn insert(&mut self, strings: &[Option<&str>]) { |
612 | | let string_array = StringViewArray::from(strings.to_vec()); |
613 | | let arr: ArrayRef = Arc::new(string_array); |
614 | | |
615 | | let mut next_index = self.indexes.len(); |
616 | | let mut actual_new_strings = vec![]; |
617 | | let mut actual_seen_indexes = vec![]; |
618 | | // update self with new values, keeping track of newly added values |
619 | | for str in strings { |
620 | | let str = str.map(|s| s.to_string()); |
621 | | let index = self.indexes.get(&str).cloned().unwrap_or_else(|| { |
622 | | actual_new_strings.push(str.clone()); |
623 | | let index = self.strings.len(); |
624 | | self.strings.push(str.clone()); |
625 | | self.indexes.insert(str, index); |
626 | | index |
627 | | }); |
628 | | actual_seen_indexes.push(index); |
629 | | } |
630 | | |
631 | | // insert the values into the map, recording what we did |
632 | | let mut seen_new_strings = vec![]; |
633 | | let mut seen_indexes = vec![]; |
634 | | self.map.insert_if_new( |
635 | | &arr, |
636 | | |s| { |
637 | | let value = s |
638 | | .map(|s| String::from_utf8(s.to_vec()).expect("Non utf8 string")); |
639 | | let index = next_index; |
640 | | next_index += 1; |
641 | | seen_new_strings.push(value); |
642 | | TestPayload { index } |
643 | | }, |
644 | | |payload| { |
645 | | seen_indexes.push(payload.index); |
646 | | }, |
647 | | ); |
648 | | |
649 | | assert_eq!(actual_seen_indexes, seen_indexes); |
650 | | assert_eq!(actual_new_strings, seen_new_strings); |
651 | | } |
652 | | |
653 | | /// Call `self.map.into_state()` validating that the strings are in the same |
654 | | /// order as they were inserted |
655 | | fn into_array(self) -> ArrayRef { |
656 | | let Self { |
657 | | map, |
658 | | strings, |
659 | | indexes: _, |
660 | | } = self; |
661 | | |
662 | | let arr = map.into_state(); |
663 | | let expected: ArrayRef = Arc::new(StringViewArray::from(strings)); |
664 | | assert_eq!(&arr, &expected); |
665 | | arr |
666 | | } |
667 | | } |
668 | | |
669 | | #[test] |
670 | | fn test_map() { |
671 | | let input = vec![ |
672 | | // Note mix of short/long strings |
673 | | Some("A"), |
674 | | Some("bcdefghijklmnop1234567"), |
675 | | Some("X"), |
676 | | Some("Y"), |
677 | | None, |
678 | | Some("qrstuvqxyzhjwya"), |
679 | | Some("✨🔥"), |
680 | | Some("🔥"), |
681 | | Some("🔥🔥🔥🔥🔥🔥"), |
682 | | ]; |
683 | | |
684 | | let mut test_map = TestMap::new(); |
685 | | test_map.insert(&input); |
686 | | test_map.insert(&input); // put it in twice |
687 | | let expected_output: ArrayRef = Arc::new(StringViewArray::from(input)); |
688 | | assert_eq!(&test_map.into_array(), &expected_output); |
689 | | } |
690 | | } |