Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/aggregates/group_values/column.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
use crate::aggregates::group_values::group_column::{
19
    ByteGroupValueBuilder, ByteViewGroupValueBuilder, GroupColumn,
20
    PrimitiveGroupValueBuilder,
21
};
22
use crate::aggregates::group_values::GroupValues;
23
use ahash::RandomState;
24
use arrow::compute::cast;
25
use arrow::datatypes::{
26
    BinaryViewType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type,
27
    Int32Type, Int64Type, Int8Type, StringViewType, UInt16Type, UInt32Type, UInt64Type,
28
    UInt8Type,
29
};
30
use arrow::record_batch::RecordBatch;
31
use arrow_array::{Array, ArrayRef};
32
use arrow_schema::{DataType, Schema, SchemaRef};
33
use datafusion_common::hash_utils::create_hashes;
34
use datafusion_common::{not_impl_err, DataFusionError, Result};
35
use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
36
use datafusion_expr::EmitTo;
37
use datafusion_physical_expr::binary_map::OutputType;
38
39
use hashbrown::raw::RawTable;
40
41
/// A [`GroupValues`] that stores multiple columns of group values.
42
///
43
///
44
pub struct GroupValuesColumn {
45
    /// The output schema
46
    schema: SchemaRef,
47
48
    /// Logically maps group values to a group_index in
49
    /// [`Self::group_values`] and in each accumulator
50
    ///
51
    /// Uses the raw API of hashbrown to avoid actually storing the
52
    /// keys (group values) in the table
53
    ///
54
    /// keys: u64 hashes of the GroupValue
55
    /// values: (hash, group_index)
56
    map: RawTable<(u64, usize)>,
57
58
    /// The size of `map` in bytes
59
    map_size: usize,
60
61
    /// The actual group by values, stored column-wise. Compare from
62
    /// the left to right, each column is stored as [`GroupColumn`].
63
    ///
64
    /// Performance tests showed that this design is faster than using the
65
    /// more general purpose [`GroupValuesRows`]. See the ticket for details:
66
    /// <https://github.com/apache/datafusion/pull/12269>
67
    ///
68
    /// [`GroupValuesRows`]: crate::aggregates::group_values::row::GroupValuesRows
69
    group_values: Vec<Box<dyn GroupColumn>>,
70
71
    /// reused buffer to store hashes
72
    hashes_buffer: Vec<u64>,
73
74
    /// Random state for creating hashes
75
    random_state: RandomState,
76
}
77
78
impl GroupValuesColumn {
79
    /// Create a new instance of GroupValuesColumn if supported for the specified schema
80
13
    pub fn try_new(schema: SchemaRef) -> Result<Self> {
81
13
        let map = RawTable::with_capacity(0);
82
13
        Ok(Self {
83
13
            schema,
84
13
            map,
85
13
            map_size: 0,
86
13
            group_values: vec![],
87
13
            hashes_buffer: Default::default(),
88
13
            random_state: Default::default(),
89
13
        })
90
13
    }
91
92
    /// Returns true if [`GroupValuesColumn`] supported for the specified schema
93
14
    pub fn supported_schema(schema: &Schema) -> bool {
94
14
        schema
95
14
            .fields()
96
14
            .iter()
97
28
            .map(|f| f.data_type())
98
14
            .all(Self::supported_type)
99
14
    }
100
101
    /// Returns true if the specified data type is supported by [`GroupValuesColumn`]
102
    ///
103
    /// In order to be supported, there must be a specialized implementation of
104
    /// [`GroupColumn`] for the data type, instantiated in [`Self::intern`]
105
28
    fn supported_type(data_type: &DataType) -> bool {
106
1
        matches!(
107
28
            *data_type,
108
            DataType::Int8
109
                | DataType::Int16
110
                | DataType::Int32
111
                | DataType::Int64
112
                | DataType::UInt8
113
                | DataType::UInt16
114
                | DataType::UInt32
115
                | DataType::UInt64
116
                | DataType::Float32
117
                | DataType::Float64
118
                | DataType::Utf8
119
                | DataType::LargeUtf8
120
                | DataType::Binary
121
                | DataType::LargeBinary
122
                | DataType::Date32
123
                | DataType::Date64
124
                | DataType::Utf8View
125
                | DataType::BinaryView
126
        )
127
28
    }
128
}
129
130
/// Creates a [`PrimitiveGroupValueBuilder`] and appends it, boxed, to `$v`
///
/// Arguments:
/// `$v`: the vector that receives the new builder
/// `$nullable`: whether the input can contain nulls; selects the `true` or
/// `false` const-generic variant of the builder
/// `$t`: the primitive type of the builder
///
macro_rules! instantiate_primitive {
    ($v:expr, $nullable:expr, $t:ty) => {
        if $nullable {
            $v.push(Box::new(PrimitiveGroupValueBuilder::<$t, true>::new()) as _)
        } else {
            $v.push(Box::new(PrimitiveGroupValueBuilder::<$t, false>::new()) as _)
        }
    };
}
148
149
impl GroupValues for GroupValuesColumn {
150
68
    fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
151
68
        let n_rows = cols[0].len();
152
68
153
68
        if self.group_values.is_empty() {
154
13
            let mut v = Vec::with_capacity(cols.len());
155
156
27
            for f in 
self.schema.fields().iter()13
{
157
27
                let nullable = f.is_nullable();
158
27
                match f.data_type() {
159
0
                    &DataType::Int8 => instantiate_primitive!(v, nullable, Int8Type),
160
0
                    &DataType::Int16 => instantiate_primitive!(v, nullable, Int16Type),
161
1
                    &DataType::Int32 => instantiate_primitive!(
v0
, nullable, Int32Type),
162
0
                    &DataType::Int64 => instantiate_primitive!(v, nullable, Int64Type),
163
0
                    &DataType::UInt8 => instantiate_primitive!(v, nullable, UInt8Type),
164
0
                    &DataType::UInt16 => instantiate_primitive!(v, nullable, UInt16Type),
165
12
                    &DataType::UInt32 => instantiate_primitive!(
v0
, nullable, UInt32Type),
166
0
                    &DataType::UInt64 => instantiate_primitive!(v, nullable, UInt64Type),
167
                    &DataType::Float32 => {
168
2
                        instantiate_primitive!(
v0
, nullable, Float32Type)
169
                    }
170
                    &DataType::Float64 => {
171
12
                        instantiate_primitive!(
v0
, nullable, Float64Type)
172
                    }
173
0
                    &DataType::Date32 => instantiate_primitive!(v, nullable, Date32Type),
174
0
                    &DataType::Date64 => instantiate_primitive!(v, nullable, Date64Type),
175
                    &DataType::Utf8 => {
176
0
                        let b = ByteGroupValueBuilder::<i32>::new(OutputType::Utf8);
177
0
                        v.push(Box::new(b) as _)
178
                    }
179
                    &DataType::LargeUtf8 => {
180
0
                        let b = ByteGroupValueBuilder::<i64>::new(OutputType::Utf8);
181
0
                        v.push(Box::new(b) as _)
182
                    }
183
                    &DataType::Binary => {
184
0
                        let b = ByteGroupValueBuilder::<i32>::new(OutputType::Binary);
185
0
                        v.push(Box::new(b) as _)
186
                    }
187
                    &DataType::LargeBinary => {
188
0
                        let b = ByteGroupValueBuilder::<i64>::new(OutputType::Binary);
189
0
                        v.push(Box::new(b) as _)
190
                    }
191
                    &DataType::Utf8View => {
192
0
                        let b = ByteViewGroupValueBuilder::<StringViewType>::new();
193
0
                        v.push(Box::new(b) as _)
194
                    }
195
                    &DataType::BinaryView => {
196
0
                        let b = ByteViewGroupValueBuilder::<BinaryViewType>::new();
197
0
                        v.push(Box::new(b) as _)
198
                    }
199
0
                    dt => {
200
0
                        return not_impl_err!("{dt} not supported in GroupValuesColumn")
201
                    }
202
                }
203
            }
204
13
            self.group_values = v;
205
55
        }
206
207
        // tracks to which group each of the input rows belongs
208
68
        groups.clear();
209
68
210
68
        // 1.1 Calculate the group keys for the group values
211
68
        let batch_hashes = &mut self.hashes_buffer;
212
68
        batch_hashes.clear();
213
68
        batch_hashes.resize(n_rows, 0);
214
68
        create_hashes(cols, &self.random_state, batch_hashes)
?0
;
215
216
98.5k
        for (row, &target_hash) in 
batch_hashes.iter().enumerate()68
{
217
98.5k
            let entry = self.map.get_mut(target_hash, |(exist_hash, group_idx)| {
218
98.3k
                // Somewhat surprisingly, this closure can be called even if the
219
98.3k
                // hash doesn't match, so check the hash first with an integer
220
98.3k
                // comparison first avoid the more expensive comparison with
221
98.3k
                // group value. https://github.com/apache/datafusion/pull/11718
222
98.3k
                if target_hash != *exist_hash {
223
1
                    return false;
224
98.3k
                }
225
226
295k
                fn check_row_equal(
227
295k
                    array_row: &dyn GroupColumn,
228
295k
                    lhs_row: usize,
229
295k
                    array: &ArrayRef,
230
295k
                    rhs_row: usize,
231
295k
                ) -> bool {
232
295k
                    array_row.equal_to(lhs_row, array, rhs_row)
233
295k
                }
234
235
295k
                for (i, group_val) in 
self.group_values.iter().enumerate()98.3k
{
236
295k
                    if !check_row_equal(group_val.as_ref(), *group_idx, &cols[i], row) {
237
0
                        return false;
238
295k
                    }
239
                }
240
241
98.3k
                true
242
98.5k
            
}98.3k
);
243
244
98.5k
            let group_idx = match entry {
245
                // Existing group_index for this group value
246
98.3k
                Some((_hash, group_idx)) => *group_idx,
247
                //  1.2 Need to create new entry for the group
248
                None => {
249
                    // Add new entry to aggr_state and save newly created index
250
                    // let group_idx = group_values.num_rows();
251
                    // group_values.push(group_rows.row(row));
252
253
163
                    let mut checklen = 0;
254
163
                    let group_idx = self.group_values[0].len();
255
329
                    for (i, group_value) in 
self.group_values.iter_mut().enumerate()163
{
256
329
                        group_value.append_val(&cols[i], row);
257
329
                        let len = group_value.len();
258
329
                        if i == 0 {
259
163
                            checklen = len;
260
163
                        } else {
261
166
                            debug_assert_eq!(checklen, len);
262
                        }
263
                    }
264
265
                    // for hasher function, use precomputed hash value
266
163
                    self.map.insert_accounted(
267
163
                        (target_hash, group_idx),
268
163
                        |(hash, _group_index)| *hash,
269
163
                        &mut self.map_size,
270
163
                    );
271
163
                    group_idx
272
                }
273
            };
274
98.5k
            groups.push(group_idx);
275
        }
276
277
68
        Ok(())
278
68
    }
279
280
68
    fn size(&self) -> usize {
281
92
        let group_values_size: usize = self.group_values.iter().map(|v| v.size()).sum();
282
68
        group_values_size + self.map_size + self.hashes_buffer.allocated_size()
283
68
    }
284
285
17
    fn is_empty(&self) -> bool {
286
17
        self.len() == 0
287
17
    }
288
289
189
    fn len(&self) -> usize {
290
189
        if self.group_values.is_empty() {
291
17
            return 0;
292
172
        }
293
172
294
172
        self.group_values[0].len()
295
189
    }
296
297
15
    fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
298
15
        let mut output = match emit_to {
299
            EmitTo::All => {
300
11
                let group_values = std::mem::take(&mut self.group_values);
301
11
                debug_assert!(self.group_values.is_empty());
302
303
11
                group_values
304
11
                    .into_iter()
305
23
                    .map(|v| v.build())
306
11
                    .collect::<Vec<_>>()
307
            }
308
4
            EmitTo::First(n) => {
309
4
                let output = self
310
4
                    .group_values
311
4
                    .iter_mut()
312
8
                    .map(|v| v.take_n(n))
313
4
                    .collect::<Vec<_>>();
314
315
                // SAFETY: self.map outlives iterator and is not modified concurrently
316
                unsafe {
317
46
                    for bucket in 
self.map.iter()4
{
318
                        // Decrement group index by n
319
46
                        match bucket.as_ref().1.checked_sub(n) {
320
                            // Group index was >= n, shift value down
321
6
                            Some(sub) => bucket.as_mut().1 = sub,
322
                            // Group index was < n, so remove from table
323
40
                            None => self.map.erase(bucket),
324
                        }
325
                    }
326
                }
327
328
4
                output
329
            }
330
        };
331
332
        // TODO: Materialize dictionaries in group keys (#7647)
333
31
        for (field, array) in 
self.schema.fields.iter().zip(&mut output)15
{
334
31
            let expected = field.data_type();
335
31
            if let DataType::Dictionary(_, 
v0
) = expected {
336
0
                let actual = array.data_type();
337
0
                if v.as_ref() != actual {
338
0
                    return Err(DataFusionError::Internal(format!(
339
0
                        "Converted group rows expected dictionary of {v} got {actual}"
340
0
                    )));
341
0
                }
342
0
                *array = cast(array.as_ref(), expected)?;
343
31
            }
344
        }
345
346
15
        Ok(output)
347
15
    }
348
349
13
    fn clear_shrink(&mut self, batch: &RecordBatch) {
350
13
        let count = batch.num_rows();
351
13
        self.group_values.clear();
352
13
        self.map.clear();
353
13
        self.map.shrink_to(count, |_| 
00
); // hasher does not matter since the map is cleared
354
13
        self.map_size = self.map.capacity() * std::mem::size_of::<(u64, usize)>();
355
13
        self.hashes_buffer.clear();
356
13
        self.hashes_buffer.shrink_to(count);
357
13
    }
358
}