/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Specialized implementation of `COUNT DISTINCT` for "Native" arrays such as |
19 | | //! [`Int64Array`] and [`Float64Array`] |
20 | | //! |
21 | | //! [`Int64Array`]: arrow::array::Int64Array |
22 | | //! [`Float64Array`]: arrow::array::Float64Array |
23 | | use std::collections::HashSet; |
24 | | use std::fmt::Debug; |
25 | | use std::hash::Hash; |
26 | | use std::sync::Arc; |
27 | | |
28 | | use ahash::RandomState; |
29 | | use arrow::array::types::ArrowPrimitiveType; |
30 | | use arrow::array::ArrayRef; |
31 | | use arrow::array::PrimitiveArray; |
32 | | use arrow::datatypes::DataType; |
33 | | |
34 | | use datafusion_common::cast::{as_list_array, as_primitive_array}; |
35 | | use datafusion_common::utils::array_into_list_array_nullable; |
36 | | use datafusion_common::utils::memory::estimate_memory_size; |
37 | | use datafusion_common::ScalarValue; |
38 | | use datafusion_expr_common::accumulator::Accumulator; |
39 | | |
40 | | use crate::utils::Hashable; |
41 | | |
42 | | #[derive(Debug)] |
43 | | pub struct PrimitiveDistinctCountAccumulator<T> |
44 | | where |
45 | | T: ArrowPrimitiveType + Send, |
46 | | T::Native: Eq + Hash, |
47 | | { |
48 | | values: HashSet<T::Native, RandomState>, |
49 | | data_type: DataType, |
50 | | } |
51 | | |
52 | | impl<T> PrimitiveDistinctCountAccumulator<T> |
53 | | where |
54 | | T: ArrowPrimitiveType + Send, |
55 | | T::Native: Eq + Hash, |
56 | | { |
57 | 0 | pub fn new(data_type: &DataType) -> Self { |
58 | 0 | Self { |
59 | 0 | values: HashSet::default(), |
60 | 0 | data_type: data_type.clone(), |
61 | 0 | } |
62 | 0 | } |
63 | | } |
64 | | |
65 | | impl<T> Accumulator for PrimitiveDistinctCountAccumulator<T> |
66 | | where |
67 | | T: ArrowPrimitiveType + Send + Debug, |
68 | | T::Native: Eq + Hash, |
69 | | { |
70 | 0 | fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> { |
71 | 0 | let arr = Arc::new( |
72 | 0 | PrimitiveArray::<T>::from_iter_values(self.values.iter().cloned()) |
73 | 0 | .with_data_type(self.data_type.clone()), |
74 | 0 | ); |
75 | 0 | let list = Arc::new(array_into_list_array_nullable(arr)); |
76 | 0 | Ok(vec![ScalarValue::List(list)]) |
77 | 0 | } |
78 | | |
79 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> { |
80 | 0 | if values.is_empty() { |
81 | 0 | return Ok(()); |
82 | 0 | } |
83 | | |
84 | 0 | let arr = as_primitive_array::<T>(&values[0])?; |
85 | 0 | arr.iter().for_each(|value| { |
86 | 0 | if let Some(value) = value { |
87 | 0 | self.values.insert(value); |
88 | 0 | } |
89 | 0 | }); |
90 | 0 |
|
91 | 0 | Ok(()) |
92 | 0 | } |
93 | | |
94 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> { |
95 | 0 | if states.is_empty() { |
96 | 0 | return Ok(()); |
97 | 0 | } |
98 | 0 | assert_eq!( |
99 | 0 | states.len(), |
100 | | 1, |
101 | 0 | "count_distinct states must be single array" |
102 | | ); |
103 | | |
104 | 0 | let arr = as_list_array(&states[0])?; |
105 | 0 | arr.iter().try_for_each(|maybe_list| { |
106 | 0 | if let Some(list) = maybe_list { |
107 | 0 | let list = as_primitive_array::<T>(&list)?; |
108 | 0 | self.values.extend(list.values()) |
109 | 0 | }; |
110 | 0 | Ok(()) |
111 | 0 | }) |
112 | 0 | } |
113 | | |
114 | 0 | fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> { |
115 | 0 | Ok(ScalarValue::Int64(Some(self.values.len() as i64))) |
116 | 0 | } |
117 | | |
118 | 0 | fn size(&self) -> usize { |
119 | 0 | let num_elements = self.values.len(); |
120 | 0 | let fixed_size = |
121 | 0 | std::mem::size_of_val(self) + std::mem::size_of_val(&self.values); |
122 | 0 |
|
123 | 0 | estimate_memory_size::<T::Native>(num_elements, fixed_size).unwrap() |
124 | 0 | } |
125 | | } |
126 | | |
127 | | #[derive(Debug)] |
128 | | pub struct FloatDistinctCountAccumulator<T> |
129 | | where |
130 | | T: ArrowPrimitiveType + Send, |
131 | | { |
132 | | values: HashSet<Hashable<T::Native>, RandomState>, |
133 | | } |
134 | | |
135 | | impl<T> FloatDistinctCountAccumulator<T> |
136 | | where |
137 | | T: ArrowPrimitiveType + Send, |
138 | | { |
139 | 0 | pub fn new() -> Self { |
140 | 0 | Self { |
141 | 0 | values: HashSet::default(), |
142 | 0 | } |
143 | 0 | } |
144 | | } |
145 | | |
146 | | impl<T> Default for FloatDistinctCountAccumulator<T> |
147 | | where |
148 | | T: ArrowPrimitiveType + Send, |
149 | | { |
150 | 0 | fn default() -> Self { |
151 | 0 | Self::new() |
152 | 0 | } |
153 | | } |
154 | | |
155 | | impl<T> Accumulator for FloatDistinctCountAccumulator<T> |
156 | | where |
157 | | T: ArrowPrimitiveType + Send + Debug, |
158 | | { |
159 | 0 | fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> { |
160 | 0 | let arr = Arc::new(PrimitiveArray::<T>::from_iter_values( |
161 | 0 | self.values.iter().map(|v| v.0), |
162 | 0 | )) as ArrayRef; |
163 | 0 | let list = Arc::new(array_into_list_array_nullable(arr)); |
164 | 0 | Ok(vec![ScalarValue::List(list)]) |
165 | 0 | } |
166 | | |
167 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> { |
168 | 0 | if values.is_empty() { |
169 | 0 | return Ok(()); |
170 | 0 | } |
171 | | |
172 | 0 | let arr = as_primitive_array::<T>(&values[0])?; |
173 | 0 | arr.iter().for_each(|value| { |
174 | 0 | if let Some(value) = value { |
175 | 0 | self.values.insert(Hashable(value)); |
176 | 0 | } |
177 | 0 | }); |
178 | 0 |
|
179 | 0 | Ok(()) |
180 | 0 | } |
181 | | |
182 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> { |
183 | 0 | if states.is_empty() { |
184 | 0 | return Ok(()); |
185 | 0 | } |
186 | 0 | assert_eq!( |
187 | 0 | states.len(), |
188 | | 1, |
189 | 0 | "count_distinct states must be single array" |
190 | | ); |
191 | | |
192 | 0 | let arr = as_list_array(&states[0])?; |
193 | 0 | arr.iter().try_for_each(|maybe_list| { |
194 | 0 | if let Some(list) = maybe_list { |
195 | 0 | let list = as_primitive_array::<T>(&list)?; |
196 | 0 | self.values |
197 | 0 | .extend(list.values().iter().map(|v| Hashable(*v))); |
198 | 0 | }; |
199 | 0 | Ok(()) |
200 | 0 | }) |
201 | 0 | } |
202 | | |
203 | 0 | fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> { |
204 | 0 | Ok(ScalarValue::Int64(Some(self.values.len() as i64))) |
205 | 0 | } |
206 | | |
207 | 0 | fn size(&self) -> usize { |
208 | 0 | let num_elements = self.values.len(); |
209 | 0 | let fixed_size = |
210 | 0 | std::mem::size_of_val(self) + std::mem::size_of_val(&self.values); |
211 | 0 |
|
212 | 0 | estimate_memory_size::<T::Native>(num_elements, fixed_size).unwrap() |
213 | 0 | } |
214 | | } |