Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/aggregates/group_values/primitive.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
use crate::aggregates::group_values::GroupValues;
19
use ahash::RandomState;
20
use arrow::array::BooleanBufferBuilder;
21
use arrow::buffer::NullBuffer;
22
use arrow::datatypes::i256;
23
use arrow::record_batch::RecordBatch;
24
use arrow_array::cast::AsArray;
25
use arrow_array::{ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray};
26
use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
27
use arrow_schema::DataType;
28
use datafusion_common::Result;
29
use datafusion_execution::memory_pool::proxy::VecAllocExt;
30
use datafusion_expr::EmitTo;
31
use half::f16;
32
use hashbrown::raw::RawTable;
33
use std::sync::Arc;
34
35
/// A trait to allow hashing of floating point numbers
36
pub(crate) trait HashValue {
37
    fn hash(&self, state: &RandomState) -> u64;
38
}
39
40
macro_rules! hash_integer {
41
    ($($t:ty),+) => {
42
        $(impl HashValue for $t {
43
            #[cfg(not(feature = "force_hash_collisions"))]
44
343
            fn hash(&self, state: &RandomState) -> u64 {
45
343
                state.hash_one(self)
46
343
            }
47
48
            #[cfg(feature = "force_hash_collisions")]
49
            fn hash(&self, _state: &RandomState) -> u64 {
50
                0
51
            }
52
        })+
53
    };
54
}
55
hash_integer!(i8, i16, i32, i64, i128, i256);
56
hash_integer!(u8, u16, u32, u64);
57
hash_integer!(IntervalDayTime, IntervalMonthDayNano);
58
59
macro_rules! hash_float {
60
    ($($t:ty),+) => {
61
        $(impl HashValue for $t {
62
            #[cfg(not(feature = "force_hash_collisions"))]
63
0
            fn hash(&self, state: &RandomState) -> u64 {
64
0
                state.hash_one(self.to_bits())
65
0
            }
66
67
            #[cfg(feature = "force_hash_collisions")]
68
            fn hash(&self, _state: &RandomState) -> u64 {
69
                0
70
            }
71
        })+
72
    };
73
}
74
75
hash_float!(f16, f32, f64);
76
77
/// A [`GroupValues`] storing a single column of primitive values
78
///
79
/// This specialization is significantly faster than using the more general
80
/// purpose `Row`s format
81
pub struct GroupValuesPrimitive<T: ArrowPrimitiveType> {
82
    /// The data type of the output array
83
    data_type: DataType,
84
    /// Stores the group index based on the hash of its value
85
    ///
86
    /// We don't store the hashes as hashing fixed width primitives
87
    /// is fast enough for this not to benefit performance
88
    map: RawTable<usize>,
89
    /// The group index of the null value if any
90
    null_group: Option<usize>,
91
    /// The values for each group index
92
    values: Vec<T::Native>,
93
    /// The random state used to generate hashes
94
    random_state: RandomState,
95
}
96
97
impl<T: ArrowPrimitiveType> GroupValuesPrimitive<T> {
98
56
    pub fn new(data_type: DataType) -> Self {
99
56
        assert!(PrimitiveArray::<T>::is_compatible(&data_type));
100
56
        Self {
101
56
            data_type,
102
56
            map: RawTable::with_capacity(128),
103
56
            values: Vec::with_capacity(128),
104
56
            null_group: None,
105
56
            random_state: Default::default(),
106
56
        }
107
56
    }
108
}
109
110
impl<T: ArrowPrimitiveType> GroupValues for GroupValuesPrimitive<T>
111
where
112
    T::Native: HashValue,
113
{
114
100
    fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
115
100
        assert_eq!(cols.len(), 1);
116
100
        groups.clear();
117
118
337
        for v in 
cols[0].as_primitive::<T>()100
{
119
337
            let group_id = match v {
120
0
                None => *self.null_group.get_or_insert_with(|| {
121
0
                    let group_id = self.values.len();
122
0
                    self.values.push(Default::default());
123
0
                    group_id
124
0
                }),
125
337
                Some(key) => {
126
337
                    let state = &self.random_state;
127
337
                    let hash = key.hash(state);
128
337
                    let insert = self.map.find_or_find_insert_slot(
129
337
                        hash,
130
337
                        |g| 
unsafe { self.values.get_unchecked(*g).is_eq(key) }143
,
131
337
                        |g| 
unsafe { self.values.get_unchecked(*g).hash(state) }6
,
132
337
                    );
133
337
134
337
                    // SAFETY: No mutation occurred since find_or_find_insert_slot
135
337
                    unsafe {
136
337
                        match insert {
137
143
                            Ok(v) => *v.as_ref(),
138
194
                            Err(slot) => {
139
194
                                let g = self.values.len();
140
194
                                self.map.insert_in_slot(hash, slot, g);
141
194
                                self.values.push(key);
142
194
                                g
143
                            }
144
                        }
145
                    }
146
                }
147
            };
148
337
            groups.push(group_id)
149
        }
150
100
        Ok(())
151
100
    }
152
153
300
    fn size(&self) -> usize {
154
300
        self.map.capacity() * std::mem::size_of::<usize>() + self.values.allocated_size()
155
300
    }
156
157
94
    fn is_empty(&self) -> bool {
158
94
        self.values.is_empty()
159
94
    }
160
161
403
    fn len(&self) -> usize {
162
403
        self.values.len()
163
403
    }
164
165
94
    fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
166
94
        fn build_primitive<T: ArrowPrimitiveType>(
167
94
            values: Vec<T::Native>,
168
94
            null_idx: Option<usize>,
169
94
        ) -> PrimitiveArray<T> {
170
94
            let nulls = null_idx.map(|null_idx| {
171
0
                let mut buffer = BooleanBufferBuilder::new(values.len());
172
0
                buffer.append_n(values.len(), true);
173
0
                buffer.set_bit(null_idx, false);
174
0
                unsafe { NullBuffer::new_unchecked(buffer.finish(), 1) }
175
94
            });
176
94
            PrimitiveArray::<T>::new(values.into(), nulls)
177
94
        }
178
179
94
        let array: PrimitiveArray<T> = match emit_to {
180
            EmitTo::All => {
181
62
                self.map.clear();
182
62
                build_primitive(std::mem::take(&mut self.values), self.null_group.take())
183
            }
184
32
            EmitTo::First(n) => {
185
                // SAFETY: self.map outlives iterator and is not modified concurrently
186
                unsafe {
187
88
                    for bucket in 
self.map.iter()32
{
188
                        // Decrement group index by n
189
88
                        match bucket.as_ref().checked_sub(n) {
190
                            // Group index was >= n, shift value down
191
32
                            Some(sub) => *bucket.as_mut() = sub,
192
                            // Group index was < n, so remove from table
193
56
                            None => self.map.erase(bucket),
194
                        }
195
                    }
196
                }
197
32
                let null_group = match &mut self.null_group {
198
0
                    Some(v) if *v >= n => {
199
0
                        *v -= n;
200
0
                        None
201
                    }
202
0
                    Some(_) => self.null_group.take(),
203
32
                    None => None,
204
                };
205
32
                let mut split = self.values.split_off(n);
206
32
                std::mem::swap(&mut self.values, &mut split);
207
32
                build_primitive(split, null_group)
208
            }
209
        };
210
94
        Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))])
211
94
    }
212
213
62
    fn clear_shrink(&mut self, batch: &RecordBatch) {
214
62
        let count = batch.num_rows();
215
62
        self.values.clear();
216
62
        self.values.shrink_to(count);
217
62
        self.map.clear();
218
62
        self.map.shrink_to(count, |_| 
00
); // hasher does not matter since the map is cleared
219
62
    }
220
}