/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use std::sync::Arc; |
19 | | |
20 | | use arrow::array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray}; |
21 | | use arrow::buffer::NullBuffer; |
22 | | use arrow::compute; |
23 | | use arrow::datatypes::ArrowPrimitiveType; |
24 | | use arrow::datatypes::DataType; |
25 | | use datafusion_common::{internal_datafusion_err, DataFusionError, Result}; |
26 | | use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; |
27 | | |
28 | | use super::accumulate::NullState; |
29 | | |
30 | | /// An accumulator that implements a single operation over |
31 | | /// [`ArrowPrimitiveType`] where the accumulated state is the same as |
32 | | /// the input type (such as `Sum`) |
33 | | /// |
34 | | /// F: The function to apply to two elements. The first argument is |
35 | | /// the existing value and should be updated with the second value |
36 | | /// (e.g. [`BitAndAssign`] style). |
37 | | /// |
38 | | /// [`BitAndAssign`]: std::ops::BitAndAssign |
39 | | #[derive(Debug)] |
40 | | pub struct PrimitiveGroupsAccumulator<T, F> |
41 | | where |
42 | | T: ArrowPrimitiveType + Send, |
43 | | F: Fn(&mut T::Native, T::Native) + Send + Sync, |
44 | | { |
45 | | /// values per group, stored as the native type |
46 | | values: Vec<T::Native>, |
47 | | |
48 | | /// The output type (needed for Decimal precision and scale) |
49 | | data_type: DataType, |
50 | | |
51 | | /// The starting value for new groups |
52 | | starting_value: T::Native, |
53 | | |
54 | | /// Track nulls in the input / filters |
55 | | null_state: NullState, |
56 | | |
57 | | /// Function that computes the primitive result |
58 | | prim_fn: F, |
59 | | } |
60 | | |
61 | | impl<T, F> PrimitiveGroupsAccumulator<T, F> |
62 | | where |
63 | | T: ArrowPrimitiveType + Send, |
64 | | F: Fn(&mut T::Native, T::Native) + Send + Sync, |
65 | | { |
66 | 1 | pub fn new(data_type: &DataType, prim_fn: F) -> Self { |
67 | 1 | Self { |
68 | 1 | values: vec![], |
69 | 1 | data_type: data_type.clone(), |
70 | 1 | null_state: NullState::new(), |
71 | 1 | starting_value: T::default_value(), |
72 | 1 | prim_fn, |
73 | 1 | } |
74 | 1 | } |
75 | | |
76 | | /// Set the starting values for new groups |
77 | 0 | pub fn with_starting_value(mut self, starting_value: T::Native) -> Self { |
78 | 0 | self.starting_value = starting_value; |
79 | 0 | self |
80 | 0 | } |
81 | | } |
82 | | |
83 | | impl<T, F> GroupsAccumulator for PrimitiveGroupsAccumulator<T, F> |
84 | | where |
85 | | T: ArrowPrimitiveType + Send, |
86 | | F: Fn(&mut T::Native, T::Native) + Send + Sync, |
87 | | { |
88 | 1 | fn update_batch( |
89 | 1 | &mut self, |
90 | 1 | values: &[ArrayRef], |
91 | 1 | group_indices: &[usize], |
92 | 1 | opt_filter: Option<&BooleanArray>, |
93 | 1 | total_num_groups: usize, |
94 | 1 | ) -> Result<()> { |
95 | 1 | assert_eq!(values.len(), 1, "single argument to update_batch"0 ); |
96 | 1 | let values = values[0].as_primitive::<T>(); |
97 | 1 | |
98 | 1 | // update values |
99 | 1 | self.values.resize(total_num_groups, self.starting_value); |
100 | 1 | |
101 | 1 | // NullState dispatches / handles tracking nulls and groups that saw no values |
102 | 1 | self.null_state.accumulate( |
103 | 1 | group_indices, |
104 | 1 | values, |
105 | 1 | opt_filter, |
106 | 1 | total_num_groups, |
107 | 3 | |group_index, new_value| { |
108 | 3 | let value = &mut self.values[group_index]; |
109 | 3 | (self.prim_fn)(value, new_value); |
110 | 3 | }, |
111 | 1 | ); |
112 | 1 | |
113 | 1 | Ok(()) |
114 | 1 | } |
115 | | |
116 | 1 | fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> { |
117 | 1 | let values = emit_to.take_needed(&mut self.values); |
118 | 1 | let nulls = self.null_state.build(emit_to); |
119 | 1 | let values = PrimitiveArray::<T>::new(values.into(), Some(nulls)) // no copy |
120 | 1 | .with_data_type(self.data_type.clone()); |
121 | 1 | Ok(Arc::new(values)) |
122 | 1 | } |
123 | | |
124 | 0 | fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> { |
125 | 0 | self.evaluate(emit_to).map(|arr| vec![arr]) |
126 | 0 | } |
127 | | |
128 | 1 | fn merge_batch( |
129 | 1 | &mut self, |
130 | 1 | values: &[ArrayRef], |
131 | 1 | group_indices: &[usize], |
132 | 1 | opt_filter: Option<&BooleanArray>, |
133 | 1 | total_num_groups: usize, |
134 | 1 | ) -> Result<()> { |
135 | 1 | // update / merge are the same |
136 | 1 | self.update_batch(values, group_indices, opt_filter, total_num_groups) |
137 | 1 | } |
138 | | |
139 | | /// Converts an input batch directly to a state batch |
140 | | /// |
141 | | /// The state is: |
142 | | /// - self.prim_fn for all non null, non filtered values |
143 | | /// - null otherwise |
144 | | /// |
145 | 0 | fn convert_to_state( |
146 | 0 | &self, |
147 | 0 | values: &[ArrayRef], |
148 | 0 | opt_filter: Option<&BooleanArray>, |
149 | 0 | ) -> Result<Vec<ArrayRef>> { |
150 | 0 | let values = values[0].as_primitive::<T>().clone(); |
151 | 0 |
|
152 | 0 | // Initializing state with starting values |
153 | 0 | let initial_state = |
154 | 0 | PrimitiveArray::<T>::from_value(self.starting_value, values.len()); |
155 | | |
156 | | // Recalculating values in case there is filter |
157 | 0 | let values = match opt_filter { |
158 | 0 | None => values, |
159 | 0 | Some(filter) => { |
160 | 0 | let (filter_values, filter_nulls) = filter.clone().into_parts(); |
161 | | // Calculating filter mask as a result of bitand of filter, and converting it to null buffer |
162 | 0 | let filter_bool = match filter_nulls { |
163 | 0 | Some(filter_nulls) => filter_nulls.inner() & &filter_values, |
164 | 0 | None => filter_values, |
165 | | }; |
166 | 0 | let filter_nulls = NullBuffer::from(filter_bool); |
167 | 0 |
|
168 | 0 | // Rebuilding input values with a new nulls mask, which is equal to |
169 | 0 | // the union of original nulls and filter mask |
170 | 0 | let (dt, values_buf, original_nulls) = values.into_parts(); |
171 | 0 | let nulls_buf = |
172 | 0 | NullBuffer::union(original_nulls.as_ref(), Some(&filter_nulls)); |
173 | 0 | PrimitiveArray::<T>::new(values_buf, nulls_buf).with_data_type(dt) |
174 | | } |
175 | | }; |
176 | | |
177 | 0 | let state_values = compute::binary_mut(initial_state, &values, |mut x, y| { |
178 | 0 | (self.prim_fn)(&mut x, y); |
179 | 0 | x |
180 | 0 | }); |
181 | 0 | let state_values = state_values |
182 | 0 | .map_err(|_| { |
183 | 0 | internal_datafusion_err!( |
184 | 0 | "initial_values underlying buffer must not be shared" |
185 | 0 | ) |
186 | 0 | })? |
187 | 0 | .map_err(DataFusionError::from)? |
188 | 0 | .with_data_type(self.data_type.clone()); |
189 | 0 |
|
190 | 0 | Ok(vec![Arc::new(state_values)]) |
191 | 0 | } |
192 | | |
193 | 0 | fn supports_convert_to_state(&self) -> bool { |
194 | 0 | true |
195 | 0 | } |
196 | | |
197 | 3 | fn size(&self) -> usize { |
198 | 3 | self.values.capacity() * std::mem::size_of::<T::Native>() + self.null_state.size() |
199 | 3 | } |
200 | | } |