/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/approx_distinct.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Defines physical expressions that can be evaluated at runtime during query execution |
19 | | |
20 | | use crate::hyperloglog::HyperLogLog; |
21 | | use arrow::array::BinaryArray; |
22 | | use arrow::array::{ |
23 | | GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray, |
24 | | }; |
25 | | use arrow::datatypes::{ |
26 | | ArrowPrimitiveType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, |
27 | | UInt32Type, UInt64Type, UInt8Type, |
28 | | }; |
29 | | use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; |
30 | | use datafusion_common::ScalarValue; |
31 | | use datafusion_common::{ |
32 | | downcast_value, internal_err, not_impl_err, DataFusionError, Result, |
33 | | }; |
34 | | use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; |
35 | | use datafusion_expr::utils::format_state_name; |
36 | | use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; |
37 | | use std::any::Any; |
38 | | use std::fmt::{Debug, Formatter}; |
39 | | use std::hash::Hash; |
40 | | use std::marker::PhantomData; |
41 | | make_udaf_expr_and_func!( |
42 | | ApproxDistinct, |
43 | | approx_distinct, |
44 | | expression, |
45 | | "approximate number of distinct input values", |
46 | | approx_distinct_udaf |
47 | | ); |
48 | | |
49 | | impl<T: Hash> From<&HyperLogLog<T>> for ScalarValue { |
50 | 0 | fn from(v: &HyperLogLog<T>) -> ScalarValue { |
51 | 0 | let values = v.as_ref().to_vec(); |
52 | 0 | ScalarValue::Binary(Some(values)) |
53 | 0 | } |
54 | | } |
55 | | |
56 | | impl<T: Hash> TryFrom<&[u8]> for HyperLogLog<T> { |
57 | | type Error = DataFusionError; |
58 | 0 | fn try_from(v: &[u8]) -> Result<HyperLogLog<T>> { |
59 | 0 | let arr: [u8; 16384] = v.try_into().map_err(|_| { |
60 | 0 | DataFusionError::Internal( |
61 | 0 | "Impossibly got invalid binary array from states".into(), |
62 | 0 | ) |
63 | 0 | })?; |
64 | 0 | Ok(HyperLogLog::<T>::new_with_registers(arr)) |
65 | 0 | } |
66 | | } |
67 | | |
68 | | impl<T: Hash> TryFrom<&ScalarValue> for HyperLogLog<T> { |
69 | | type Error = DataFusionError; |
70 | 0 | fn try_from(v: &ScalarValue) -> Result<HyperLogLog<T>> { |
71 | 0 | if let ScalarValue::Binary(Some(slice)) = v { |
72 | 0 | slice.as_slice().try_into() |
73 | | } else { |
74 | 0 | internal_err!( |
75 | 0 | "Impossibly got invalid scalar value while converting to HyperLogLog" |
76 | 0 | ) |
77 | | } |
78 | 0 | } |
79 | | } |
80 | | |
81 | | #[derive(Debug)] |
82 | | struct NumericHLLAccumulator<T> |
83 | | where |
84 | | T: ArrowPrimitiveType, |
85 | | T::Native: Hash, |
86 | | { |
87 | | hll: HyperLogLog<T::Native>, |
88 | | } |
89 | | |
90 | | impl<T> NumericHLLAccumulator<T> |
91 | | where |
92 | | T: ArrowPrimitiveType, |
93 | | T::Native: Hash, |
94 | | { |
95 | | /// new approx_distinct accumulator |
96 | 0 | pub fn new() -> Self { |
97 | 0 | Self { |
98 | 0 | hll: HyperLogLog::new(), |
99 | 0 | } |
100 | 0 | } |
101 | | } |
102 | | |
103 | | #[derive(Debug)] |
104 | | struct StringHLLAccumulator<T> |
105 | | where |
106 | | T: OffsetSizeTrait, |
107 | | { |
108 | | hll: HyperLogLog<String>, |
109 | | phantom_data: PhantomData<T>, |
110 | | } |
111 | | |
112 | | impl<T> StringHLLAccumulator<T> |
113 | | where |
114 | | T: OffsetSizeTrait, |
115 | | { |
116 | | /// new approx_distinct accumulator |
117 | 0 | pub fn new() -> Self { |
118 | 0 | Self { |
119 | 0 | hll: HyperLogLog::new(), |
120 | 0 | phantom_data: PhantomData, |
121 | 0 | } |
122 | 0 | } |
123 | | } |
124 | | |
125 | | #[derive(Debug)] |
126 | | struct BinaryHLLAccumulator<T> |
127 | | where |
128 | | T: OffsetSizeTrait, |
129 | | { |
130 | | hll: HyperLogLog<Vec<u8>>, |
131 | | phantom_data: PhantomData<T>, |
132 | | } |
133 | | |
134 | | impl<T> BinaryHLLAccumulator<T> |
135 | | where |
136 | | T: OffsetSizeTrait, |
137 | | { |
138 | | /// new approx_distinct accumulator |
139 | 0 | pub fn new() -> Self { |
140 | 0 | Self { |
141 | 0 | hll: HyperLogLog::new(), |
142 | 0 | phantom_data: PhantomData, |
143 | 0 | } |
144 | 0 | } |
145 | | } |
146 | | |
147 | | macro_rules! default_accumulator_impl { |
148 | | () => { |
149 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
150 | 0 | assert_eq!(1, states.len(), "expect only 1 element in the states"); |
151 | 0 | let binary_array = downcast_value!(states[0], BinaryArray); |
152 | 0 | for v in binary_array.iter() { |
153 | 0 | let v = v.ok_or_else(|| { |
154 | 0 | DataFusionError::Internal( |
155 | 0 | "Impossibly got empty binary array from states".into(), |
156 | 0 | ) |
157 | 0 | })?; |
158 | 0 | let other = v.try_into()?; |
159 | 0 | self.hll.merge(&other); |
160 | | } |
161 | 0 | Ok(()) |
162 | 0 | } |
163 | | |
164 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
165 | 0 | let value = ScalarValue::from(&self.hll); |
166 | 0 | Ok(vec![value]) |
167 | 0 | } |
168 | | |
169 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
170 | 0 | Ok(ScalarValue::UInt64(Some(self.hll.count() as u64))) |
171 | 0 | } |
172 | | |
173 | 0 | fn size(&self) -> usize { |
174 | 0 | // HLL has static size |
175 | 0 | std::mem::size_of_val(self) |
176 | 0 | } |
177 | | }; |
178 | | } |
179 | | |
180 | | impl<T> Accumulator for BinaryHLLAccumulator<T> |
181 | | where |
182 | | T: OffsetSizeTrait, |
183 | | { |
184 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
185 | 0 | let array: &GenericBinaryArray<T> = |
186 | 0 | downcast_value!(values[0], GenericBinaryArray, T); |
187 | | // `flatten()` drops the `None` entries, so null values are skipped |
188 | 0 | self.hll |
189 | 0 | .extend(array.into_iter().flatten().map(|v| v.to_vec())); |
190 | 0 | Ok(()) |
191 | 0 | } |
192 | | |
193 | | default_accumulator_impl!(); |
194 | | } |
195 | | |
196 | | impl<T> Accumulator for StringHLLAccumulator<T> |
197 | | where |
198 | | T: OffsetSizeTrait, |
199 | | { |
200 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
201 | 0 | let array: &GenericStringArray<T> = |
202 | 0 | downcast_value!(values[0], GenericStringArray, T); |
203 | | // `flatten()` drops the `None` entries, so null values are skipped |
204 | 0 | self.hll |
205 | 0 | .extend(array.into_iter().flatten().map(|i| i.to_string())); |
206 | 0 | Ok(()) |
207 | 0 | } |
208 | | |
209 | | default_accumulator_impl!(); |
210 | | } |
211 | | |
212 | | impl<T> Accumulator for NumericHLLAccumulator<T> |
213 | | where |
214 | | T: ArrowPrimitiveType + Debug, |
215 | | T::Native: Hash, |
216 | | { |
217 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
218 | 0 | let array: &PrimitiveArray<T> = downcast_value!(values[0], PrimitiveArray, T); |
219 | | // `flatten()` drops the `None` entries, so null values are skipped |
220 | 0 | self.hll.extend(array.into_iter().flatten()); |
221 | 0 | Ok(()) |
222 | 0 | } |
223 | | |
224 | | default_accumulator_impl!(); |
225 | | } |
226 | | |
227 | | impl Debug for ApproxDistinct { |
228 | 0 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
229 | 0 | f.debug_struct("ApproxDistinct") |
230 | 0 | .field("name", &self.name()) |
231 | 0 | .field("signature", &self.signature) |
232 | 0 | .finish() |
233 | 0 | } |
234 | | } |
235 | | |
236 | | impl Default for ApproxDistinct { |
237 | 0 | fn default() -> Self { |
238 | 0 | Self::new() |
239 | 0 | } |
240 | | } |
241 | | |
242 | | pub struct ApproxDistinct { |
243 | | signature: Signature, |
244 | | } |
245 | | |
246 | | impl ApproxDistinct { |
247 | 0 | pub fn new() -> Self { |
248 | 0 | Self { |
249 | 0 | signature: Signature::any(1, Volatility::Immutable), |
250 | 0 | } |
251 | 0 | } |
252 | | } |
253 | | |
254 | | impl AggregateUDFImpl for ApproxDistinct { |
255 | 0 | fn as_any(&self) -> &dyn Any { |
256 | 0 | self |
257 | 0 | } |
258 | | |
259 | 0 | fn name(&self) -> &str { |
260 | 0 | "approx_distinct" |
261 | 0 | } |
262 | | |
263 | 0 | fn signature(&self) -> &Signature { |
264 | 0 | &self.signature |
265 | 0 | } |
266 | | |
267 | 0 | fn return_type(&self, _: &[DataType]) -> Result<DataType> { |
268 | 0 | Ok(DataType::UInt64) |
269 | 0 | } |
270 | | |
271 | 0 | fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { |
272 | 0 | Ok(vec![Field::new( |
273 | 0 | format_state_name(args.name, "hll_registers"), |
274 | 0 | DataType::Binary, |
275 | 0 | false, |
276 | 0 | )]) |
277 | 0 | } |
278 | | |
279 | 0 | fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
280 | 0 | let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; |
281 | | |
282 | 0 | let accumulator: Box<dyn Accumulator> = match data_type { |
283 | | // TODO: u8, i8, u16, and i16 should really be handled with a bitmap rather than HLL |
284 | | // TODO support for boolean (trivial case) |
285 | | // https://github.com/apache/datafusion/issues/1109 |
286 | 0 | DataType::UInt8 => Box::new(NumericHLLAccumulator::<UInt8Type>::new()), |
287 | 0 | DataType::UInt16 => Box::new(NumericHLLAccumulator::<UInt16Type>::new()), |
288 | 0 | DataType::UInt32 => Box::new(NumericHLLAccumulator::<UInt32Type>::new()), |
289 | 0 | DataType::UInt64 => Box::new(NumericHLLAccumulator::<UInt64Type>::new()), |
290 | 0 | DataType::Int8 => Box::new(NumericHLLAccumulator::<Int8Type>::new()), |
291 | 0 | DataType::Int16 => Box::new(NumericHLLAccumulator::<Int16Type>::new()), |
292 | 0 | DataType::Int32 => Box::new(NumericHLLAccumulator::<Int32Type>::new()), |
293 | 0 | DataType::Int64 => Box::new(NumericHLLAccumulator::<Int64Type>::new()), |
294 | 0 | DataType::Utf8 => Box::new(StringHLLAccumulator::<i32>::new()), |
295 | 0 | DataType::LargeUtf8 => Box::new(StringHLLAccumulator::<i64>::new()), |
296 | 0 | DataType::Binary => Box::new(BinaryHLLAccumulator::<i32>::new()), |
297 | 0 | DataType::LargeBinary => Box::new(BinaryHLLAccumulator::<i64>::new()), |
298 | 0 | other => { |
299 | 0 | return not_impl_err!( |
300 | 0 | "Support for 'approx_distinct' for data type {other} is not implemented" |
301 | 0 | ) |
302 | | } |
303 | | }; |
304 | 0 | Ok(accumulator) |
305 | 0 | } |
306 | | } |