Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-plan/src/aggregates/group_values/row.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
use crate::aggregates::group_values::GroupValues;
19
use ahash::RandomState;
20
use arrow::compute::cast;
21
use arrow::record_batch::RecordBatch;
22
use arrow::row::{RowConverter, Rows, SortField};
23
use arrow_array::{Array, ArrayRef, ListArray, StructArray};
24
use arrow_schema::{DataType, SchemaRef};
25
use datafusion_common::hash_utils::create_hashes;
26
use datafusion_common::Result;
27
use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
28
use datafusion_expr::EmitTo;
29
use hashbrown::raw::RawTable;
30
use std::sync::Arc;
31
32
/// A [`GroupValues`] making use of [`Rows`]
///
/// This is a general implementation of [`GroupValues`] that works for any
/// combination of data types and number of columns, including nested types such as
/// structs and lists.
///
/// It uses the arrow-rs [`Rows`] to store the group values, which is a row-wise
/// representation.
pub struct GroupValuesRows {
    /// The output schema
    schema: SchemaRef,

    /// Converter for the group values
    ///
    /// Built from the schema's field types; converts incoming group-by
    /// columns into the row format and back into arrays on emit.
    row_converter: RowConverter,

    /// Logically maps group values to a group_index in
    /// [`Self::group_values`] and in each accumulator
    ///
    /// Uses the raw API of hashbrown to avoid actually storing the
    /// keys (group values) in the table
    ///
    /// keys: u64 hashes of the GroupValue
    /// values: (hash, group_index)
    map: RawTable<(u64, usize)>,

    /// The size of `map` in bytes
    ///
    /// Maintained by `insert_accounted` during `intern` and recomputed
    /// from the table capacity in `clear_shrink`.
    map_size: usize,

    /// The actual group by values, stored in arrow [`Row`] format.
    /// `group_values[i]` holds the group value for group_index `i`.
    ///
    /// The row format is used to compare group keys quickly and store
    /// them efficiently in memory. Quick comparison is especially
    /// important for multi-column group keys.
    ///
    /// `None` until the first call to `intern`; temporarily taken
    /// (and restored) during `intern` and `emit`.
    ///
    /// [`Row`]: arrow::row::Row
    group_values: Option<Rows>,

    /// reused buffer to store hashes
    hashes_buffer: Vec<u64>,

    /// reused buffer to store rows
    ///
    /// Cleared and refilled with the incoming batch's rows on every
    /// `intern` call to avoid re-allocating.
    rows_buffer: Rows,

    /// Random state for creating hashes
    random_state: RandomState,
}
79
80
impl GroupValuesRows {
81
1
    pub fn try_new(schema: SchemaRef) -> Result<Self> {
82
1
        let row_converter = RowConverter::new(
83
1
            schema
84
1
                .fields()
85
1
                .iter()
86
1
                .map(|f| SortField::new(f.data_type().clone()))
87
1
                .collect(),
88
1
        )
?0
;
89
90
1
        let map = RawTable::with_capacity(0);
91
1
92
1
        let starting_rows_capacity = 1000;
93
1
94
1
        let starting_data_capacity = 64 * starting_rows_capacity;
95
1
        let rows_buffer =
96
1
            row_converter.empty_rows(starting_rows_capacity, starting_data_capacity);
97
1
        Ok(Self {
98
1
            schema,
99
1
            row_converter,
100
1
            map,
101
1
            map_size: 0,
102
1
            group_values: None,
103
1
            hashes_buffer: Default::default(),
104
1
            rows_buffer,
105
1
            random_state: Default::default(),
106
1
        })
107
1
    }
108
}
109
110
impl GroupValues for GroupValuesRows {
    /// Maps each input row in `cols` to a group index, creating new groups
    /// as needed, and writes the indices into `groups` (one per input row).
    fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
        // Convert the group keys into the row format
        let group_rows = &mut self.rows_buffer;
        group_rows.clear();
        self.row_converter.append(group_rows, cols)?;
        let n_rows = group_rows.num_rows();

        // Take ownership of the stored group values; created lazily on
        // the first call (restored into `self` at the end of this method)
        let mut group_values = match self.group_values.take() {
            Some(group_values) => group_values,
            None => self.row_converter.empty_rows(0, 0),
        };

        // tracks to which group each of the input rows belongs
        groups.clear();

        // 1.1 Calculate the group keys for the group values
        let batch_hashes = &mut self.hashes_buffer;
        batch_hashes.clear();
        batch_hashes.resize(n_rows, 0);
        create_hashes(cols, &self.random_state, batch_hashes)?;

        for (row, &target_hash) in batch_hashes.iter().enumerate() {
            let entry = self.map.get_mut(target_hash, |(exist_hash, group_idx)| {
                // Somewhat surprisingly, this closure can be called even if the
                // hash doesn't match, so check the hash first with an integer
                // comparison first avoid the more expensive comparison with
                // group value. https://github.com/apache/datafusion/pull/11718
                target_hash == *exist_hash
                    // verify that the group that we are inserting with hash is
                    // actually the same key value as the group in
                    // existing_idx  (aka group_values @ row)
                    && group_rows.row(row) == group_values.row(*group_idx)
            });

            let group_idx = match entry {
                // Existing group_index for this group value
                Some((_hash, group_idx)) => *group_idx,
                //  1.2 Need to create new entry for the group
                None => {
                    // Add new entry to aggr_state and save newly created index
                    let group_idx = group_values.num_rows();
                    group_values.push(group_rows.row(row));

                    // for hasher function, use precomputed hash value
                    self.map.insert_accounted(
                        (target_hash, group_idx),
                        |(hash, _group_index)| *hash,
                        &mut self.map_size,
                    );
                    group_idx
                }
            };
            groups.push(group_idx);
        }

        self.group_values = Some(group_values);

        Ok(())
    }

    /// Returns the total memory used, in bytes: converter state, stored
    /// group rows, hash table, and the two reusable buffers.
    fn size(&self) -> usize {
        let group_values_size = self.group_values.as_ref().map(|v| v.size()).unwrap_or(0);
        self.row_converter.size()
            + group_values_size
            + self.map_size
            + self.rows_buffer.size()
            + self.hashes_buffer.allocated_size()
    }

    /// Returns true if no groups have been stored yet.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the number of distinct groups stored so far
    /// (0 before the first `intern`).
    fn len(&self) -> usize {
        self.group_values
            .as_ref()
            .map(|group_values| group_values.num_rows())
            .unwrap_or(0)
    }

    /// Converts stored group values back into output arrays.
    ///
    /// `EmitTo::All` emits every group and clears the stored rows;
    /// `EmitTo::First(n)` emits the first `n` groups and renumbers the
    /// remaining ones.
    ///
    /// # Panics
    /// Panics if called before any groups were interned
    /// (`self.group_values` is `None`).
    fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        let mut group_values = self
            .group_values
            .take()
            .expect("Can not emit from empty rows");

        let mut output = match emit_to {
            EmitTo::All => {
                let output = self.row_converter.convert_rows(&group_values)?;
                group_values.clear();
                output
            }
            EmitTo::First(n) => {
                let groups_rows = group_values.iter().take(n);
                let output = self.row_converter.convert_rows(groups_rows)?;
                // Clear out first n group keys by copying them to a new Rows.
                // TODO file some ticket in arrow-rs to make this more efficient?
                let mut new_group_values = self.row_converter.empty_rows(0, 0);
                for row in group_values.iter().skip(n) {
                    new_group_values.push(row);
                }
                std::mem::swap(&mut new_group_values, &mut group_values);

                // SAFETY: self.map outlives iterator and is not modified concurrently
                unsafe {
                    for bucket in self.map.iter() {
                        // Decrement group index by n
                        match bucket.as_ref().1.checked_sub(n) {
                            // Group index was >= n, shift value down
                            Some(sub) => bucket.as_mut().1 = sub,
                            // Group index was < n, so remove from table
                            None => self.map.erase(bucket),
                        }
                    }
                }
                output
            }
        };

        // TODO: Materialize dictionaries in group keys
        // https://github.com/apache/datafusion/issues/7647
        //
        // The row format loses dictionary encoding, so re-encode each
        // output column whose expected schema type requires it.
        for (field, array) in self.schema.fields.iter().zip(&mut output) {
            let expected = field.data_type();
            *array = dictionary_encode_if_necessary(
                Arc::<dyn arrow_array::Array>::clone(array),
                expected,
            )?;
        }

        // Put the (possibly truncated) rows back for subsequent calls
        self.group_values = Some(group_values);
        Ok(output)
    }

    /// Clears all groups and shrinks internal allocations to fit a
    /// batch of `batch.num_rows()` rows, keeping buffers for reuse.
    fn clear_shrink(&mut self, batch: &RecordBatch) {
        let count = batch.num_rows();
        self.group_values = self.group_values.take().map(|mut rows| {
            rows.clear();
            rows
        });
        self.map.clear();
        self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared
        // Re-derive tracked size from the (possibly shrunken) capacity
        self.map_size = self.map.capacity() * std::mem::size_of::<(u64, usize)>();
        self.hashes_buffer.clear();
        self.hashes_buffer.shrink_to(count);
    }
}
257
258
3
fn dictionary_encode_if_necessary(
259
3
    array: ArrayRef,
260
3
    expected: &DataType,
261
3
) -> Result<ArrayRef> {
262
3
    match (expected, array.data_type()) {
263
1
        (DataType::Struct(expected_fields), _) => {
264
1
            let struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
265
1
            let arrays = expected_fields
266
1
                .iter()
267
1
                .zip(struct_array.columns())
268
2
                .map(|(expected_field, column)| {
269
2
                    dictionary_encode_if_necessary(
270
2
                        Arc::<dyn arrow_array::Array>::clone(column),
271
2
                        expected_field.data_type(),
272
2
                    )
273
2
                })
274
1
                .collect::<Result<Vec<_>>>()
?0
;
275
276
1
            Ok(Arc::new(StructArray::try_new(
277
1
                expected_fields.clone(),
278
1
                arrays,
279
1
                struct_array.nulls().cloned(),
280
1
            )
?0
))
281
        }
282
0
        (DataType::List(expected_field), &DataType::List(_)) => {
283
0
            let list = array.as_any().downcast_ref::<ListArray>().unwrap();
284
0
285
0
            Ok(Arc::new(ListArray::try_new(
286
0
                Arc::<arrow_schema::Field>::clone(expected_field),
287
0
                list.offsets().clone(),
288
0
                dictionary_encode_if_necessary(
289
0
                    Arc::<dyn arrow_array::Array>::clone(list.values()),
290
0
                    expected_field.data_type(),
291
0
                )?,
292
0
                list.nulls().cloned(),
293
0
            )?))
294
        }
295
2
        (DataType::Dictionary(_, _), _) => Ok(cast(array.as_ref(), expected)
?0
),
296
0
        (_, _) => Ok(Arc::<dyn arrow_array::Array>::clone(&array)),
297
    }
298
3
}