Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
use std::sync::Arc;
19
20
use arrow::array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray};
21
use arrow::buffer::NullBuffer;
22
use arrow::compute;
23
use arrow::datatypes::ArrowPrimitiveType;
24
use arrow::datatypes::DataType;
25
use datafusion_common::{internal_datafusion_err, DataFusionError, Result};
26
use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};
27
28
use super::accumulate::NullState;
29
30
/// An accumulator that implements a single operation over
31
/// [`ArrowPrimitiveType`] where the accumulated state is the same as
32
/// the input type (such as `Sum`)
33
///
34
/// F: The function to apply to two elements. The first argument is
35
/// the existing value and should be updated with the second value
36
/// (e.g. [`BitAndAssign`] style).
37
///
38
/// [`BitAndAssign`]: std::ops::BitAndAssign
39
#[derive(Debug)]
40
pub struct PrimitiveGroupsAccumulator<T, F>
41
where
42
    T: ArrowPrimitiveType + Send,
43
    F: Fn(&mut T::Native, T::Native) + Send + Sync,
44
{
45
    /// values per group, stored as the native type
46
    values: Vec<T::Native>,
47
48
    /// The output type (needed for Decimal precision and scale)
49
    data_type: DataType,
50
51
    /// The starting value for new groups
52
    starting_value: T::Native,
53
54
    /// Track nulls in the input / filters
55
    null_state: NullState,
56
57
    /// Function that computes the primitive result
58
    prim_fn: F,
59
}
60
61
impl<T, F> PrimitiveGroupsAccumulator<T, F>
62
where
63
    T: ArrowPrimitiveType + Send,
64
    F: Fn(&mut T::Native, T::Native) + Send + Sync,
65
{
66
1
    pub fn new(data_type: &DataType, prim_fn: F) -> Self {
67
1
        Self {
68
1
            values: vec![],
69
1
            data_type: data_type.clone(),
70
1
            null_state: NullState::new(),
71
1
            starting_value: T::default_value(),
72
1
            prim_fn,
73
1
        }
74
1
    }
75
76
    /// Set the starting values for new groups
77
0
    pub fn with_starting_value(mut self, starting_value: T::Native) -> Self {
78
0
        self.starting_value = starting_value;
79
0
        self
80
0
    }
81
}
82
83
impl<T, F> GroupsAccumulator for PrimitiveGroupsAccumulator<T, F>
84
where
85
    T: ArrowPrimitiveType + Send,
86
    F: Fn(&mut T::Native, T::Native) + Send + Sync,
87
{
88
1
    fn update_batch(
89
1
        &mut self,
90
1
        values: &[ArrayRef],
91
1
        group_indices: &[usize],
92
1
        opt_filter: Option<&BooleanArray>,
93
1
        total_num_groups: usize,
94
1
    ) -> Result<()> {
95
1
        assert_eq!(values.len(), 1, 
"single argument to update_batch"0
);
96
1
        let values = values[0].as_primitive::<T>();
97
1
98
1
        // update values
99
1
        self.values.resize(total_num_groups, self.starting_value);
100
1
101
1
        // NullState dispatches / handles tracking nulls and groups that saw no values
102
1
        self.null_state.accumulate(
103
1
            group_indices,
104
1
            values,
105
1
            opt_filter,
106
1
            total_num_groups,
107
3
            |group_index, new_value| {
108
3
                let value = &mut self.values[group_index];
109
3
                (self.prim_fn)(value, new_value);
110
3
            },
111
1
        );
112
1
113
1
        Ok(())
114
1
    }
115
116
1
    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
117
1
        let values = emit_to.take_needed(&mut self.values);
118
1
        let nulls = self.null_state.build(emit_to);
119
1
        let values = PrimitiveArray::<T>::new(values.into(), Some(nulls)) // no copy
120
1
            .with_data_type(self.data_type.clone());
121
1
        Ok(Arc::new(values))
122
1
    }
123
124
0
    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
125
0
        self.evaluate(emit_to).map(|arr| vec![arr])
126
0
    }
127
128
1
    fn merge_batch(
129
1
        &mut self,
130
1
        values: &[ArrayRef],
131
1
        group_indices: &[usize],
132
1
        opt_filter: Option<&BooleanArray>,
133
1
        total_num_groups: usize,
134
1
    ) -> Result<()> {
135
1
        // update / merge are the same
136
1
        self.update_batch(values, group_indices, opt_filter, total_num_groups)
137
1
    }
138
139
    /// Converts an input batch directly to a state batch
140
    ///
141
    /// The state is:
142
    /// - self.prim_fn for all non null, non filtered values
143
    /// - null otherwise
144
    ///
145
0
    fn convert_to_state(
146
0
        &self,
147
0
        values: &[ArrayRef],
148
0
        opt_filter: Option<&BooleanArray>,
149
0
    ) -> Result<Vec<ArrayRef>> {
150
0
        let values = values[0].as_primitive::<T>().clone();
151
0
152
0
        // Initializing state with starting values
153
0
        let initial_state =
154
0
            PrimitiveArray::<T>::from_value(self.starting_value, values.len());
155
156
        // Recalculating values in case there is filter
157
0
        let values = match opt_filter {
158
0
            None => values,
159
0
            Some(filter) => {
160
0
                let (filter_values, filter_nulls) = filter.clone().into_parts();
161
                // Calculating filter mask as a result of bitand of filter, and converting it to null buffer
162
0
                let filter_bool = match filter_nulls {
163
0
                    Some(filter_nulls) => filter_nulls.inner() & &filter_values,
164
0
                    None => filter_values,
165
                };
166
0
                let filter_nulls = NullBuffer::from(filter_bool);
167
0
168
0
                // Rebuilding input values with a new nulls mask, which is equal to
169
0
                // the union of original nulls and filter mask
170
0
                let (dt, values_buf, original_nulls) = values.into_parts();
171
0
                let nulls_buf =
172
0
                    NullBuffer::union(original_nulls.as_ref(), Some(&filter_nulls));
173
0
                PrimitiveArray::<T>::new(values_buf, nulls_buf).with_data_type(dt)
174
            }
175
        };
176
177
0
        let state_values = compute::binary_mut(initial_state, &values, |mut x, y| {
178
0
            (self.prim_fn)(&mut x, y);
179
0
            x
180
0
        });
181
0
        let state_values = state_values
182
0
            .map_err(|_| {
183
0
                internal_datafusion_err!(
184
0
                    "initial_values underlying buffer must not be shared"
185
0
                )
186
0
            })?
187
0
            .map_err(DataFusionError::from)?
188
0
            .with_data_type(self.data_type.clone());
189
0
190
0
        Ok(vec![Arc::new(state_values)])
191
0
    }
192
193
0
    fn supports_convert_to_state(&self) -> bool {
194
0
        true
195
0
    }
196
197
3
    fn size(&self) -> usize {
198
3
        self.values.capacity() * std::mem::size_of::<T::Native>() + self.null_state.size()
199
3
    }
200
}