Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-expr-common/src/binary_map.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! [`ArrowBytesMap`] and [`ArrowBytesSet`] for storing maps/sets of values from
19
//! StringArray / LargeStringArray / BinaryArray / LargeBinaryArray.
20
21
use ahash::RandomState;
22
use arrow::array::cast::AsArray;
23
use arrow::array::types::{ByteArrayType, GenericBinaryType, GenericStringType};
24
use arrow::array::{
25
    Array, ArrayRef, BooleanBufferBuilder, BufferBuilder, GenericBinaryArray,
26
    GenericStringArray, OffsetSizeTrait,
27
};
28
use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer};
29
use arrow::datatypes::DataType;
30
use datafusion_common::hash_utils::create_hashes;
31
use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt};
32
use std::any::type_name;
33
use std::fmt::Debug;
34
use std::mem;
35
use std::ops::Range;
36
use std::sync::Arc;
37
38
/// Selects which Arrow array type a map or set emits.
///
/// The non-view variants (`Utf8`, `Binary`) are handled by [`ArrowBytesMap`];
/// the view variants are rejected there.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum OutputType {
    /// Emit a `StringArray` or a `LargeStringArray`
    Utf8,
    /// Emit a `StringViewArray`
    Utf8View,
    /// Emit a `BinaryArray` or a `LargeBinaryArray`
    Binary,
    /// Emit a `BinaryViewArray`
    BinaryView,
}
50
51
/// HashSet optimized for storing string or binary values that can produce that
52
/// the final set as a GenericStringArray with minimal copies.
53
#[derive(Debug)]
54
pub struct ArrowBytesSet<O: OffsetSizeTrait>(ArrowBytesMap<O, ()>);
55
56
impl<O: OffsetSizeTrait> ArrowBytesSet<O> {
57
0
    pub fn new(output_type: OutputType) -> Self {
58
0
        Self(ArrowBytesMap::new(output_type))
59
0
    }
60
61
    /// Return the contents of this set and replace it with a new empty
62
    /// set with the same output type
63
0
    pub fn take(&mut self) -> Self {
64
0
        Self(self.0.take())
65
0
    }
66
67
    /// Inserts each value from `values` into the set
68
0
    pub fn insert(&mut self, values: &ArrayRef) {
69
0
        fn make_payload_fn(_value: Option<&[u8]>) {}
70
0
        fn observe_payload_fn(_payload: ()) {}
71
0
        self.0
72
0
            .insert_if_new(values, make_payload_fn, observe_payload_fn);
73
0
    }
74
75
    /// Converts this set into a `StringArray`/`LargeStringArray` or
76
    /// `BinaryArray`/`LargeBinaryArray` containing each distinct value that
77
    /// was interned. This is done without copying the values.
78
0
    pub fn into_state(self) -> ArrayRef {
79
0
        self.0.into_state()
80
0
    }
81
82
    /// Returns the total number of distinct values (including nulls) seen so far
83
0
    pub fn len(&self) -> usize {
84
0
        self.0.len()
85
0
    }
86
87
0
    pub fn is_empty(&self) -> bool {
88
0
        self.0.is_empty()
89
0
    }
90
91
    /// returns the total number of distinct values (not including nulls) seen so far
92
0
    pub fn non_null_len(&self) -> usize {
93
0
        self.0.non_null_len()
94
0
    }
95
96
    /// Return the total size, in bytes, of memory used to store the data in
97
    /// this set, not including `self`
98
0
    pub fn size(&self) -> usize {
99
0
        self.0.size()
100
0
    }
101
}
102
103
/// Optimized map for storing Arrow "bytes" types (`String`, `LargeString`,
104
/// `Binary`, and `LargeBinary`) values that can produce the set of keys on
105
/// output as `GenericBinaryArray` without copies.
106
///
107
/// Equivalent to `HashSet<String, V>` but with better performance for arrow
108
/// data.
109
///
110
/// # Generic Arguments
111
///
112
/// * `O`: OffsetSize (String/LargeString)
113
/// * `V`: payload type
114
///
115
/// # Description
116
///
117
/// This is a specialized HashMap with the following properties:
118
///
119
/// 1. Optimized for storing and emitting Arrow byte types  (e.g.
120
///    `StringArray` / `BinaryArray`) very efficiently by minimizing copying of
121
///    the string values themselves, both when inserting and when emitting the
122
///    final array.
123
///
124
///
125
/// 2. Retains the insertion order of entries in the final array. The values are
126
///    in the same order as they were inserted.
127
///
128
/// Note this structure can be used as a `HashSet` by specifying the value type
129
/// as `()`, as is done by [`ArrowBytesSet`].
130
///
131
/// This map is used by the special `COUNT DISTINCT` aggregate function to
132
/// store the distinct values, and by the `GROUP BY` operator to store
133
/// group values when they are a single string array.
134
///
135
/// # Example
136
///
137
/// The following diagram shows how the map would store the four strings
138
/// "Foo", NULL, "Bar", "TheQuickBrownFox":
139
///
140
/// * `hashtable` stores entries for each distinct string that has been
141
///   inserted. The entries contain the payload as well as information about the
142
///   value (either an offset or the actual bytes, see `Entry` docs for more
143
///   details)
144
///
145
/// * `offsets` stores offsets into `buffer` for each distinct string value,
146
///   following the same convention as the offsets in a `StringArray` or
147
///   `LargeStringArray`.
148
///
149
/// * `buffer` stores the actual byte data
150
///
151
/// * `null`: stores the index and payload of the null value, in this case the
152
///   second value (index 1)
153
///
154
/// ```text
155
/// ┌───────────────────────────────────┐    ┌─────┐    ┌────┐
156
/// │                ...                │    │  0  │    │FooB│
157
/// │ ┌──────────────────────────────┐  │    │  0  │    │arTh│
158
/// │ │      <Entry for "Bar">       │  │    │  3  │    │eQui│
159
/// │ │            len: 3            │  │    │  3  │    │ckBr│
160
/// │ │   offset_or_inline: "Bar"    │  │    │  6  │    │ownF│
161
/// │ │         payload:...          │  │    │     │    │ox  │
162
/// │ └──────────────────────────────┘  │    │     │    │    │
163
/// │                ...                │    └─────┘    └────┘
164
/// │ ┌──────────────────────────────┐  │
165
/// │ │<Entry for "TheQuickBrownFox">│  │    offsets    buffer
166
/// │ │           len: 16            │  │
167
/// │ │     offset_or_inline: 6      │  │    ┌───────────────┐
168
/// │ │         payload: ...         │  │    │    Some(1)    │
169
/// │ └──────────────────────────────┘  │    │ payload: ...  │
170
/// │                ...                │    └───────────────┘
171
/// └───────────────────────────────────┘
172
///                                              null
173
///               HashTable
174
/// ```
175
///
176
/// # Entry Format
177
///
178
/// Entries stored in a [`ArrowBytesMap`] represents a value that is either
179
/// stored inline or in the buffer
180
///
181
/// This helps the case where there are many short (less than 8 bytes) strings
182
/// that are the same (e.g. "MA", "CA", "NY", "TX", etc)
183
///
184
/// ```text
185
///                                                                ┌──────────────────┐
186
///                                                  ─ ─ ─ ─ ─ ─ ─▶│...               │
187
///                                                 │              │TheQuickBrownFox  │
188
///                                                                │...               │
189
///                                                 │              │                  │
190
///                                                                └──────────────────┘
191
///                                                 │               buffer of u8
192
///
193
///                                                 │
194
///                        ┌────────────────┬───────────────┬───────────────┐
195
///  Storing               │                │ starting byte │  length, in   │
196
///  "TheQuickBrownFox"    │   hash value   │   offset in   │  bytes (not   │
197
///  (long string)         │                │    buffer     │  characters)  │
198
///                        └────────────────┴───────────────┴───────────────┘
199
///                              8 bytes          8 bytes       4 or 8
200
///
201
///
202
///                         ┌───────────────┬─┬─┬─┬─┬─┬─┬─┬─┬───────────────┐
203
/// Storing "foobar"        │               │ │ │ │ │ │ │ │ │  length, in   │
204
/// (short string)          │  hash value   │?│?│f│o│o│b│a│r│  bytes (not   │
205
///                         │               │ │ │ │ │ │ │ │ │  characters)  │
206
///                         └───────────────┴─┴─┴─┴─┴─┴─┴─┴─┴───────────────┘
207
///                              8 bytes         8 bytes        4 or 8
208
/// ```
209
pub struct ArrowBytesMap<O, V>
210
where
211
    O: OffsetSizeTrait,
212
    V: Debug + PartialEq + Eq + Clone + Copy + Default,
213
{
214
    /// Should the output be String or Binary?
215
    output_type: OutputType,
216
    /// Underlying hash set for each distinct value
217
    map: hashbrown::raw::RawTable<Entry<O, V>>,
218
    /// Total size of the map in bytes
219
    map_size: usize,
220
    /// In progress arrow `Buffer` containing all values
221
    buffer: BufferBuilder<u8>,
222
    /// Offsets into `buffer` for each distinct  value. These offsets as used
223
    /// directly to create the final `GenericBinaryArray`. The `i`th string is
224
    /// stored in the range `offsets[i]..offsets[i+1]` in `buffer`. Null values
225
    /// are stored as a zero length string.
226
    offsets: Vec<O>,
227
    /// random state used to generate hashes
228
    random_state: RandomState,
229
    /// buffer that stores hash values (reused across batches to save allocations)
230
    hashes_buffer: Vec<u64>,
231
    /// `(payload, null_index)` for the 'null' value, if any
232
    /// NOTE null_index is the logical index in the final array, not the index
233
    /// in the buffer
234
    null: Option<(V, usize)>,
235
}
236
237
/// Number of entries the hash table is pre-sized to hold when a map is created
const INITIAL_MAP_CAPACITY: usize = 128;
/// Number of bytes initially reserved for the concatenated value data (8 KiB)
pub const INITIAL_BUFFER_CAPACITY: usize = 1024 * 8;
241
impl<O: OffsetSizeTrait, V> ArrowBytesMap<O, V>
242
where
243
    V: Debug + PartialEq + Eq + Clone + Copy + Default,
244
{
245
0
    pub fn new(output_type: OutputType) -> Self {
246
0
        Self {
247
0
            output_type,
248
0
            map: hashbrown::raw::RawTable::with_capacity(INITIAL_MAP_CAPACITY),
249
0
            map_size: 0,
250
0
            buffer: BufferBuilder::new(INITIAL_BUFFER_CAPACITY),
251
0
            offsets: vec![O::default()], // first offset is always 0
252
0
            random_state: RandomState::new(),
253
0
            hashes_buffer: vec![],
254
0
            null: None,
255
0
        }
256
0
    }
257
258
    /// Return the contents of this map and replace it with a new empty map with
259
    /// the same output type
260
0
    pub fn take(&mut self) -> Self {
261
0
        let mut new_self = Self::new(self.output_type);
262
0
        mem::swap(self, &mut new_self);
263
0
        new_self
264
0
    }
265
266
    /// Inserts each value from `values` into the map, invoking `payload_fn` for
267
    /// each value if *not* already present, deferring the allocation of the
268
    /// payload until it is needed.
269
    ///
270
    /// Note that this is different than a normal map that would replace the
271
    /// existing entry
272
    ///
273
    /// # Arguments:
274
    ///
275
    /// `values`: array whose values are inserted
276
    ///
277
    /// `make_payload_fn`:  invoked for each value that is not already present
278
    /// to create the payload, in order of the values in `values`
279
    ///
280
    /// `observe_payload_fn`: invoked once, for each value in `values`, that was
281
    /// already present in the map, with corresponding payload value.
282
    ///
283
    /// # Returns
284
    ///
285
    /// The payload value for the entry, either the existing value or
286
    /// the newly inserted value
287
    ///
288
    /// # Safety:
289
    ///
290
    /// Note that `make_payload_fn` and `observe_payload_fn` are only invoked
291
    /// with valid values from `values`, not for the `NULL` value.
292
0
    pub fn insert_if_new<MP, OP>(
293
0
        &mut self,
294
0
        values: &ArrayRef,
295
0
        make_payload_fn: MP,
296
0
        observe_payload_fn: OP,
297
0
    ) where
298
0
        MP: FnMut(Option<&[u8]>) -> V,
299
0
        OP: FnMut(V),
300
0
    {
301
0
        // Sanity array type
302
0
        match self.output_type {
303
            OutputType::Binary => {
304
0
                assert!(matches!(
305
0
                    values.data_type(),
306
                    DataType::Binary | DataType::LargeBinary
307
                ));
308
0
                self.insert_if_new_inner::<MP, OP, GenericBinaryType<O>>(
309
0
                    values,
310
0
                    make_payload_fn,
311
0
                    observe_payload_fn,
312
0
                )
313
            }
314
            OutputType::Utf8 => {
315
0
                assert!(matches!(
316
0
                    values.data_type(),
317
                    DataType::Utf8 | DataType::LargeUtf8
318
                ));
319
0
                self.insert_if_new_inner::<MP, OP, GenericStringType<O>>(
320
0
                    values,
321
0
                    make_payload_fn,
322
0
                    observe_payload_fn,
323
0
                )
324
            }
325
0
            _ => unreachable!("View types should use `ArrowBytesViewMap`"),
326
        };
327
0
    }
328
329
    /// Generic version of [`Self::insert_if_new`] that handles `ByteArrayType`
330
    /// (both String and Binary)
331
    ///
332
    /// Note this is the only function that is generic on [`ByteArrayType`], which
333
    /// avoids having to template the entire structure,  making the code
334
    /// simpler and understand and reducing code bloat due to duplication.
335
    ///
336
    /// See comments on `insert_if_new` for more details
337
0
    fn insert_if_new_inner<MP, OP, B>(
338
0
        &mut self,
339
0
        values: &ArrayRef,
340
0
        mut make_payload_fn: MP,
341
0
        mut observe_payload_fn: OP,
342
0
    ) where
343
0
        MP: FnMut(Option<&[u8]>) -> V,
344
0
        OP: FnMut(V),
345
0
        B: ByteArrayType,
346
0
    {
347
0
        // step 1: compute hashes
348
0
        let batch_hashes = &mut self.hashes_buffer;
349
0
        batch_hashes.clear();
350
0
        batch_hashes.resize(values.len(), 0);
351
0
        create_hashes(&[values.clone()], &self.random_state, batch_hashes)
352
0
            // hash is supported for all types and create_hashes only
353
0
            // returns errors for unsupported types
354
0
            .unwrap();
355
0
356
0
        // step 2: insert each value into the set, if not already present
357
0
        let values = values.as_bytes::<B>();
358
0
359
0
        // Ensure lengths are equivalent
360
0
        assert_eq!(values.len(), batch_hashes.len());
361
362
0
        for (value, &hash) in values.iter().zip(batch_hashes.iter()) {
363
            // handle null value
364
0
            let Some(value) = value else {
365
0
                let payload = if let Some(&(payload, _offset)) = self.null.as_ref() {
366
0
                    payload
367
                } else {
368
0
                    let payload = make_payload_fn(None);
369
0
                    let null_index = self.offsets.len() - 1;
370
0
                    // nulls need a zero length in the offset buffer
371
0
                    let offset = self.buffer.len();
372
0
                    self.offsets.push(O::usize_as(offset));
373
0
                    self.null = Some((payload, null_index));
374
0
                    payload
375
                };
376
0
                observe_payload_fn(payload);
377
0
                continue;
378
            };
379
380
            // get the value as bytes
381
0
            let value: &[u8] = value.as_ref();
382
0
            let value_len = O::usize_as(value.len());
383
384
            // value is "small"
385
0
            let payload = if value.len() <= SHORT_VALUE_LEN {
386
0
                let inline = value.iter().fold(0usize, |acc, &x| acc << 8 | x as usize);
387
0
388
0
                // is value is already present in the set?
389
0
                let entry = self.map.get_mut(hash, |header| {
390
0
                    // compare value if hashes match
391
0
                    if header.len != value_len {
392
0
                        return false;
393
0
                    }
394
0
                    // value is stored inline so no need to consult buffer
395
0
                    // (this is the "small string optimization")
396
0
                    inline == header.offset_or_inline
397
0
                });
398
399
0
                if let Some(entry) = entry {
400
0
                    entry.payload
401
                }
402
                // if no existing entry, make a new one
403
                else {
404
                    // Put the small values into buffer and offsets so it appears
405
                    // the output array, but store the actual bytes inline for
406
                    // comparison
407
0
                    self.buffer.append_slice(value);
408
0
                    self.offsets.push(O::usize_as(self.buffer.len()));
409
0
                    let payload = make_payload_fn(Some(value));
410
0
                    let new_header = Entry {
411
0
                        hash,
412
0
                        len: value_len,
413
0
                        offset_or_inline: inline,
414
0
                        payload,
415
0
                    };
416
0
                    self.map.insert_accounted(
417
0
                        new_header,
418
0
                        |header| header.hash,
419
0
                        &mut self.map_size,
420
0
                    );
421
0
                    payload
422
                }
423
            }
424
            // value is not "small"
425
            else {
426
                // Check if the value is already present in the set
427
0
                let entry = self.map.get_mut(hash, |header| {
428
0
                    // compare value if hashes match
429
0
                    if header.len != value_len {
430
0
                        return false;
431
0
                    }
432
0
                    // Need to compare the bytes in the buffer
433
0
                    // SAFETY: buffer is only appended to, and we correctly inserted values and offsets
434
0
                    let existing_value =
435
0
                        unsafe { self.buffer.as_slice().get_unchecked(header.range()) };
436
0
                    value == existing_value
437
0
                });
438
439
0
                if let Some(entry) = entry {
440
0
                    entry.payload
441
                }
442
                // if no existing entry, make a new one
443
                else {
444
                    // Put the small values into buffer and offsets so it
445
                    // appears the output array, and store that offset
446
                    // so the bytes can be compared if needed
447
0
                    let offset = self.buffer.len(); // offset of start for data
448
0
                    self.buffer.append_slice(value);
449
0
                    self.offsets.push(O::usize_as(self.buffer.len()));
450
0
451
0
                    let payload = make_payload_fn(Some(value));
452
0
                    let new_header = Entry {
453
0
                        hash,
454
0
                        len: value_len,
455
0
                        offset_or_inline: offset,
456
0
                        payload,
457
0
                    };
458
0
                    self.map.insert_accounted(
459
0
                        new_header,
460
0
                        |header| header.hash,
461
0
                        &mut self.map_size,
462
0
                    );
463
0
                    payload
464
                }
465
            };
466
0
            observe_payload_fn(payload);
467
        }
468
        // Check for overflow in offsets (if more data was sent than can be represented)
469
0
        if O::from_usize(self.buffer.len()).is_none() {
470
0
            panic!(
471
0
                "Put {} bytes in buffer, more than can be represented by a {}",
472
0
                self.buffer.len(),
473
0
                type_name::<O>()
474
0
            );
475
0
        }
476
0
    }
477
478
    /// Converts this set into a `StringArray`, `LargeStringArray`,
479
    /// `BinaryArray`, or `LargeBinaryArray` containing each distinct value
480
    /// that was inserted. This is done without copying the values.
481
    ///
482
    /// The values are guaranteed to be returned in the same order in which
483
    /// they were first seen.
484
0
    pub fn into_state(self) -> ArrayRef {
485
0
        let Self {
486
0
            output_type,
487
0
            map: _,
488
0
            map_size: _,
489
0
            offsets,
490
0
            mut buffer,
491
0
            random_state: _,
492
0
            hashes_buffer: _,
493
0
            null,
494
0
        } = self;
495
0
496
0
        // Only make a `NullBuffer` if there was a null value
497
0
        let nulls = null.map(|(_payload, null_index)| {
498
0
            let num_values = offsets.len() - 1;
499
0
            single_null_buffer(num_values, null_index)
500
0
        });
501
0
        // SAFETY: the offsets were constructed correctly in `insert_if_new` --
502
0
        // monotonically increasing, overflows were checked.
503
0
        let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };
504
0
        let values = buffer.finish();
505
0
506
0
        match output_type {
507
            OutputType::Binary => {
508
                // SAFETY: the offsets were constructed correctly
509
0
                Arc::new(unsafe {
510
0
                    GenericBinaryArray::new_unchecked(offsets, values, nulls)
511
0
                })
512
            }
513
            OutputType::Utf8 => {
514
                // SAFETY:
515
                // 1. the offsets were constructed safely
516
                //
517
                // 2. we asserted the input arrays were all the correct type and
518
                // thus since all the values that went in were valid (e.g. utf8)
519
                // so are all the values that come out
520
0
                Arc::new(unsafe {
521
0
                    GenericStringArray::new_unchecked(offsets, values, nulls)
522
0
                })
523
            }
524
0
            _ => unreachable!("View types should use `ArrowBytesViewMap`"),
525
        }
526
0
    }
527
528
    /// Total number of entries (including null, if present)
529
0
    pub fn len(&self) -> usize {
530
0
        self.non_null_len() + self.null.map(|_| 1).unwrap_or(0)
531
0
    }
532
533
    /// Is the set empty?
534
0
    pub fn is_empty(&self) -> bool {
535
0
        self.map.is_empty() && self.null.is_none()
536
0
    }
537
538
    /// Number of non null entries
539
0
    pub fn non_null_len(&self) -> usize {
540
0
        self.map.len()
541
0
    }
542
543
    /// Return the total size, in bytes, of memory used to store the data in
544
    /// this set, not including `self`
545
0
    pub fn size(&self) -> usize {
546
0
        self.map_size
547
0
            + self.buffer.capacity() * mem::size_of::<u8>()
548
0
            + self.offsets.allocated_size()
549
0
            + self.hashes_buffer.allocated_size()
550
0
    }
551
}
552
553
/// Returns a `NullBuffer` with a single null value at the given index
554
0
fn single_null_buffer(num_values: usize, null_index: usize) -> NullBuffer {
555
0
    let mut bool_builder = BooleanBufferBuilder::new(num_values);
556
0
    bool_builder.append_n(num_values, true);
557
0
    bool_builder.set_bit(null_index, false);
558
0
    NullBuffer::from(bool_builder.finish())
559
0
}
560
561
impl<O: OffsetSizeTrait, V> Debug for ArrowBytesMap<O, V>
562
where
563
    V: Debug + PartialEq + Eq + Clone + Copy + Default,
564
{
565
0
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
566
0
        f.debug_struct("ArrowBytesMap")
567
0
            .field("map", &"<map>")
568
0
            .field("map_size", &self.map_size)
569
0
            .field("buffer", &self.buffer)
570
0
            .field("random_state", &self.random_state)
571
0
            .field("hashes_buffer", &self.hashes_buffer)
572
0
            .finish()
573
0
    }
574
}
575
576
/// Longest value, in bytes, that is packed directly into an `Entry` instead
/// of being referenced by offset: one byte per byte of a `usize`.
const SHORT_VALUE_LEN: usize = (usize::BITS / 8) as usize;
578
579
/// Entry in the hash table -- see [`ArrowBytesMap`] for more details
580
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
581
struct Entry<O, V>
582
where
583
    O: OffsetSizeTrait,
584
    V: Debug + PartialEq + Eq + Clone + Copy + Default,
585
{
586
    /// hash of the value (stored to avoid recomputing it in hash table check)
587
    hash: u64,
588
    /// if len =< [`SHORT_VALUE_LEN`]: the data inlined
589
    /// if len > [`SHORT_VALUE_LEN`], the offset of where the data starts
590
    offset_or_inline: usize,
591
    /// length of the value, in bytes (use O here so we use only i32 for
592
    /// strings, rather 64 bit usize)
593
    len: O,
594
    /// value stored by the entry
595
    payload: V,
596
}
597
598
impl<O, V> Entry<O, V>
599
where
600
    O: OffsetSizeTrait,
601
    V: Debug + PartialEq + Eq + Clone + Copy + Default,
602
{
603
    /// returns self.offset..self.offset + self.len
604
    #[inline(always)]
605
0
    fn range(&self) -> Range<usize> {
606
0
        self.offset_or_inline..self.offset_or_inline + self.len.as_usize()
607
0
    }
608
}
609
610
#[cfg(test)]
611
mod tests {
612
    use super::*;
613
    use arrow::array::{BinaryArray, LargeBinaryArray, StringArray};
614
    use std::collections::HashMap;
615
616
    #[test]
617
    fn string_set_empty() {
618
        let mut set = ArrowBytesSet::<i32>::new(OutputType::Utf8);
619
        let array: ArrayRef = Arc::new(StringArray::new_null(0));
620
        set.insert(&array);
621
        assert_eq!(set.len(), 0);
622
        assert_eq!(set.non_null_len(), 0);
623
        assert_set(set, &[]);
624
    }
625
626
    #[test]
627
    fn string_set_one_null() {
628
        let mut set = ArrowBytesSet::<i32>::new(OutputType::Utf8);
629
        let array: ArrayRef = Arc::new(StringArray::new_null(1));
630
        set.insert(&array);
631
        assert_eq!(set.len(), 1);
632
        assert_eq!(set.non_null_len(), 0);
633
        assert_set(set, &[None]);
634
    }
635
636
    #[test]
637
    fn string_set_many_null() {
638
        let mut set = ArrowBytesSet::<i32>::new(OutputType::Utf8);
639
        let array: ArrayRef = Arc::new(StringArray::new_null(11));
640
        set.insert(&array);
641
        assert_eq!(set.len(), 1);
642
        assert_eq!(set.non_null_len(), 0);
643
        assert_set(set, &[None]);
644
    }
645
646
    #[test]
647
    fn string_set_basic_i32() {
648
        test_string_set_basic::<i32>();
649
    }
650
651
    #[test]
652
    fn string_set_basic_i64() {
653
        test_string_set_basic::<i64>();
654
    }
655
656
    fn test_string_set_basic<O: OffsetSizeTrait>() {
657
        // basic test for mixed small and large string values
658
        let values = GenericStringArray::<O>::from(vec![
659
            Some("a"),
660
            Some("b"),
661
            Some("CXCCCCCCCC"), // 10 bytes
662
            Some(""),
663
            Some("cbcxx"), // 5 bytes
664
            None,
665
            Some("AAAAAAAA"),  // 8 bytes
666
            Some("BBBBBQBBB"), // 9 bytes
667
            Some("a"),
668
            Some("cbcxx"),
669
            Some("b"),
670
            Some("cbcxx"),
671
            Some(""),
672
            None,
673
            Some("BBBBBQBBB"),
674
            Some("BBBBBQBBB"),
675
            Some("AAAAAAAA"),
676
            Some("CXCCCCCCCC"),
677
        ]);
678
679
        let mut set = ArrowBytesSet::<O>::new(OutputType::Utf8);
680
        let array: ArrayRef = Arc::new(values);
681
        set.insert(&array);
682
        // values mut appear be in the order they were inserted
683
        assert_set(
684
            set,
685
            &[
686
                Some("a"),
687
                Some("b"),
688
                Some("CXCCCCCCCC"),
689
                Some(""),
690
                Some("cbcxx"),
691
                None,
692
                Some("AAAAAAAA"),
693
                Some("BBBBBQBBB"),
694
            ],
695
        );
696
    }
697
698
    #[test]
699
    fn string_set_non_utf8_32() {
700
        test_string_set_non_utf8::<i32>();
701
    }
702
703
    #[test]
704
    fn string_set_non_utf8_64() {
705
        test_string_set_non_utf8::<i64>();
706
    }
707
708
    fn test_string_set_non_utf8<O: OffsetSizeTrait>() {
709
        // basic test for mixed small and large string values
710
        let values = GenericStringArray::<O>::from(vec![
711
            Some("a"),
712
            Some("✨🔥"),
713
            Some("🔥"),
714
            Some("✨✨✨"),
715
            Some("foobarbaz"),
716
            Some("🔥"),
717
            Some("✨🔥"),
718
        ]);
719
720
        let mut set = ArrowBytesSet::<O>::new(OutputType::Utf8);
721
        let array: ArrayRef = Arc::new(values);
722
        set.insert(&array);
723
        // strings mut appear be in the order they were inserted
724
        assert_set(
725
            set,
726
            &[
727
                Some("a"),
728
                Some("✨🔥"),
729
                Some("🔥"),
730
                Some("✨✨✨"),
731
                Some("foobarbaz"),
732
            ],
733
        );
734
    }
735
736
    // asserts that the set contains the expected strings, in the same order
737
    fn assert_set<O: OffsetSizeTrait>(set: ArrowBytesSet<O>, expected: &[Option<&str>]) {
738
        let strings = set.into_state();
739
        let strings = strings.as_string::<O>();
740
        let state = strings.into_iter().collect::<Vec<_>>();
741
        assert_eq!(state, expected);
742
    }
743
744
    // Test use of binary output type
745
    #[test]
746
    fn test_binary_set() {
747
        let values: ArrayRef = Arc::new(BinaryArray::from_opt_vec(vec![
748
            Some(b"a"),
749
            Some(b"CXCCCCCCCC"),
750
            None,
751
            Some(b"CXCCCCCCCC"),
752
        ]));
753
754
        let expected: ArrayRef = Arc::new(BinaryArray::from_opt_vec(vec![
755
            Some(b"a"),
756
            Some(b"CXCCCCCCCC"),
757
            None,
758
        ]));
759
760
        let mut set = ArrowBytesSet::<i32>::new(OutputType::Binary);
761
        set.insert(&values);
762
        assert_eq!(&set.into_state(), &expected);
763
    }
764
765
    /// Inserting a `LargeBinaryArray` into an `i64`-offset set built with
    /// `OutputType::Binary` drops the repeated value and keeps the values
    /// (including the null) in first-seen order.
    #[test]
    fn test_large_binary_set() {
        let input = LargeBinaryArray::from_opt_vec(vec![
            Some(b"a"),
            Some(b"CXCCCCCCCC"),
            None,
            Some(b"CXCCCCCCCC"),
        ]);
        let values: ArrayRef = Arc::new(input);

        let mut set = ArrowBytesSet::<i64>::new(OutputType::Binary);
        set.insert(&values);
        let actual = set.into_state();

        // the second "CXCCCCCCCC" is a duplicate and must not appear again
        let expected: ArrayRef = Arc::new(LargeBinaryArray::from_opt_vec(vec![
            Some(b"a"),
            Some(b"CXCCCCCCCC"),
            None,
        ]));
        assert_eq!(&actual, &expected);
    }
785
786
    /// A set configured for `OutputType::Utf8` must panic when handed binary
    /// data.
    #[test]
    #[should_panic(
        expected = "matches!(values.data_type(), DataType::Utf8 | DataType::LargeUtf8)"
    )]
    fn test_mismatched_types() {
        let mut set = ArrowBytesSet::<i64>::new(OutputType::Utf8);
        let binary = LargeBinaryArray::from_opt_vec(vec![Some(b"a")]);
        let values: ArrayRef = Arc::new(binary);
        set.insert(&values);
    }
797
798
    /// Inserting large (`i64`-offset) binary data into a set that expects
    /// `i32` offsets should panic.
    #[test]
    #[should_panic]
    fn test_mismatched_sizes() {
        let mut set = ArrowBytesSet::<i32>::new(OutputType::Binary);
        let large: ArrayRef = Arc::new(LargeBinaryArray::from_opt_vec(vec![Some(b"a")]));
        set.insert(&large);
    }
807
808
    // Putting more bytes of string data into an `i32`-offset set than an
    // `i32` offset can represent must panic.
    #[test]
    #[should_panic(
        expected = "Put 2147483648 bytes in buffer, more than can be represented by a i32"
    )]
    fn test_string_overflow() {
        let mut set = ArrowBytesSet::<i32>::new(OutputType::Utf8);
        for value in ["a", "b", "c"] {
            // Each string is 1 GiB (2^30 bytes). The second insert brings the
            // buffer to 2^31 = 2147483648 bytes (one past `i32::MAX`), which is
            // the total reported in the expected panic; "c" is never inserted.
            let arr: ArrayRef =
                Arc::new(StringArray::from_iter_values([value.repeat(1 << 30)]));
            set.insert(&arr);
        }
    }
822
823
    // Inserting new strings increases the set's reported memory, re-inserting
    // strings it already holds does not, and the reported size exceeds the raw
    // bytes of the distinct strings stored.
    #[test]
    fn test_string_set_memory_usage() {
        // Total bytes of string data in `arr`; nulls contribute 0.
        fn total_len(arr: &GenericStringArray<i32>) -> usize {
            arr.iter().map(|s| s.map_or(0, str::len)).sum()
        }

        let strings1 = GenericStringArray::<i32>::from(vec![
            Some("a"),
            Some("b"),
            Some("CXCCCCCCCC"), // 10 bytes
            Some("AAAAAAAA"),   // 8 bytes
            Some("BBBBBQBBB"),  // 9 bytes
        ]);
        let total_strings1_len = total_len(&strings1);
        let values1: ArrayRef = Arc::new(strings1);

        // Much larger strings in strings2
        let strings2 = GenericStringArray::<i32>::from(vec![
            "FOO".repeat(1000),
            "BAR".repeat(2000),
            "BAZ".repeat(3000),
        ]);
        let total_strings2_len = total_len(&strings2);
        let values2: ArrayRef = Arc::new(strings2);

        let mut set = ArrowBytesSet::<i32>::new(OutputType::Utf8);
        let size_empty = set.size();

        set.insert(&values1);
        let size_after_values1 = set.size();
        assert!(size_empty < size_after_values1);
        assert!(
            size_after_values1 > total_strings1_len,
            "expect {size_after_values1} to be more than {total_strings1_len}"
        );
        assert!(size_after_values1 < total_strings1_len + total_strings2_len);

        // inserting the same strings should not affect the size
        set.insert(&values1);
        assert_eq!(set.size(), size_after_values1);

        // inserting the large strings should increase the reported size
        set.insert(&values2);
        let size_after_values2 = set.size();
        assert!(size_after_values2 > size_after_values1);
        assert!(size_after_values2 > total_strings1_len + total_strings2_len);
    }
873
874
    /// Round-trips a batch of values through a [`TestMap`] twice and checks
    /// the final state holds each distinct value exactly once, in order.
    #[test]
    fn test_map() {
        // mix of short/long strings, multi-byte UTF-8 and a null
        let input = vec![
            Some("A"),
            Some("bcdefghijklmnop"),
            Some("X"),
            Some("Y"),
            None,
            Some("qrstuvqxyzhjwya"),
            Some("✨🔥"),
            Some("🔥"),
            Some("🔥🔥🔥🔥🔥🔥"),
        ];

        let mut test_map = TestMap::new();
        // insert the same batch twice so the second pass only hits values
        // that are already present
        for _ in 0..2 {
            test_map.insert(&input);
        }

        let expected_output: ArrayRef = Arc::new(StringArray::from(input));
        assert_eq!(&test_map.into_array(), &expected_output);
    }
895
896
    /// Payload stored in the [`ArrowBytesMap`] under test.
    #[derive(Debug, PartialEq, Eq, Default, Clone, Copy)]
    struct TestPayload {
        /// Sequential index assigned when the value was first inserted: each
        /// new distinct value gets the next index.
        index: usize,
    }
901
902
    /// Wraps an [`ArrowBytesMap`], validating its invariants
903
    struct TestMap {
904
        map: ArrowBytesMap<i32, TestPayload>,
905
        // stores distinct strings seen, in order
906
        strings: Vec<Option<String>>,
907
        // map strings to index in strings
908
        indexes: HashMap<Option<String>, usize>,
909
    }
910
911
    impl Debug for TestMap {
912
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
913
            f.debug_struct("TestMap")
914
                .field("map", &"...")
915
                .field("strings", &self.strings)
916
                .field("indexes", &self.indexes)
917
                .finish()
918
        }
919
    }
920
921
    impl TestMap {
922
        /// creates a map with TestPayloads for the given strings and then
923
        /// validates the payloads
924
        fn new() -> Self {
925
            Self {
926
                map: ArrowBytesMap::new(OutputType::Utf8),
927
                strings: vec![],
928
                indexes: HashMap::new(),
929
            }
930
        }
931
932
        /// Inserts strings into the map
933
        fn insert(&mut self, strings: &[Option<&str>]) {
934
            let string_array = StringArray::from(strings.to_vec());
935
            let arr: ArrayRef = Arc::new(string_array);
936
937
            let mut next_index = self.indexes.len();
938
            let mut actual_new_strings = vec![];
939
            let mut actual_seen_indexes = vec![];
940
            // update self with new values, keeping track of newly added values
941
            for str in strings {
942
                let str = str.map(|s| s.to_string());
943
                let index = self.indexes.get(&str).cloned().unwrap_or_else(|| {
944
                    actual_new_strings.push(str.clone());
945
                    let index = self.strings.len();
946
                    self.strings.push(str.clone());
947
                    self.indexes.insert(str, index);
948
                    index
949
                });
950
                actual_seen_indexes.push(index);
951
            }
952
953
            // insert the values into the map, recording what we did
954
            let mut seen_new_strings = vec![];
955
            let mut seen_indexes = vec![];
956
            self.map.insert_if_new(
957
                &arr,
958
                |s| {
959
                    let value = s
960
                        .map(|s| String::from_utf8(s.to_vec()).expect("Non utf8 string"));
961
                    let index = next_index;
962
                    next_index += 1;
963
                    seen_new_strings.push(value);
964
                    TestPayload { index }
965
                },
966
                |payload| {
967
                    seen_indexes.push(payload.index);
968
                },
969
            );
970
971
            assert_eq!(actual_seen_indexes, seen_indexes);
972
            assert_eq!(actual_new_strings, seen_new_strings);
973
        }
974
975
        /// Call `self.map.into_array()` validating that the strings are in the same
976
        /// order as they were inserted
977
        fn into_array(self) -> ArrayRef {
978
            let Self {
979
                map,
980
                strings,
981
                indexes: _,
982
            } = self;
983
984
            let arr = map.into_state();
985
            let expected: ArrayRef = Arc::new(StringArray::from(strings));
986
            assert_eq!(&arr, &expected);
987
            arr
988
        }
989
    }
990
}