Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Specialized implementation of `COUNT DISTINCT` for "Native" arrays such as
19
//! [`Int64Array`] and [`Float64Array`]
20
//!
21
//! [`Int64Array`]: arrow::array::Int64Array
22
//! [`Float64Array`]: arrow::array::Float64Array
23
use std::collections::HashSet;
24
use std::fmt::Debug;
25
use std::hash::Hash;
26
use std::sync::Arc;
27
28
use ahash::RandomState;
29
use arrow::array::types::ArrowPrimitiveType;
30
use arrow::array::ArrayRef;
31
use arrow::array::PrimitiveArray;
32
use arrow::datatypes::DataType;
33
34
use datafusion_common::cast::{as_list_array, as_primitive_array};
35
use datafusion_common::utils::array_into_list_array_nullable;
36
use datafusion_common::utils::memory::estimate_memory_size;
37
use datafusion_common::ScalarValue;
38
use datafusion_expr_common::accumulator::Accumulator;
39
40
use crate::utils::Hashable;
41
42
/// Accumulator that collects the distinct native values of an Arrow primitive
/// type `T` in a hash set.
///
/// `T::Native` must be `Eq + Hash` so the values can be stored directly as
/// set members.
#[derive(Debug)]
43
pub struct PrimitiveDistinctCountAccumulator<T>
44
where
45
    T: ArrowPrimitiveType + Send,
46
    T::Native: Eq + Hash,
47
{
48
    /// Distinct native values collected so far (hashed with `ahash`'s
    /// `RandomState`).
    values: HashSet<T::Native, RandomState>,
49
    /// Arrow data type attached to the array that is built when the set is
    /// serialized.
    data_type: DataType,
50
}
51
52
impl<T> PrimitiveDistinctCountAccumulator<T>
53
where
54
    T: ArrowPrimitiveType + Send,
55
    T::Native: Eq + Hash,
56
{
57
0
    /// Creates an accumulator with an empty value set.
    ///
    /// `data_type` is cloned and kept in the accumulator; the caller keeps
    /// ownership of the original.
    pub fn new(data_type: &DataType) -> Self {
58
0
        Self {
59
0
            values: HashSet::default(),
60
0
            data_type: data_type.clone(),
61
0
        }
62
0
    }
63
}
64
65
impl<T> Accumulator for PrimitiveDistinctCountAccumulator<T>
66
where
67
    T: ArrowPrimitiveType + Send + Debug,
68
    T::Native: Eq + Hash,
69
{
70
0
    /// Serializes the collected distinct values as this accumulator's state.
    ///
    /// The set is copied into a `PrimitiveArray<T>` tagged with the stored
    /// `data_type`, wrapped by `array_into_list_array_nullable`, and returned
    /// as a one-element vector holding a `ScalarValue::List`.
    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
71
0
        let arr = Arc::new(
72
0
            // Elements are emitted in the hash set's iteration order.
            PrimitiveArray::<T>::from_iter_values(self.values.iter().cloned())
73
0
                .with_data_type(self.data_type.clone()),
74
0
        );
75
0
        let list = Arc::new(array_into_list_array_nullable(arr));
76
0
        Ok(vec![ScalarValue::List(list)])
77
0
    }
78
79
0
    /// Adds every non-null value of `values[0]` to the distinct set.
    ///
    /// Returns `Ok(())` immediately when `values` is empty; arrays after the
    /// first are not read. Any error returned by `as_primitive_array::<T>`
    /// for the first array is propagated.
    fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
80
0
        if values.is_empty() {
81
0
            return Ok(());
82
0
        }
83
84
0
        let arr = as_primitive_array::<T>(&values[0])?;
85
0
        arr.iter().for_each(|value| {
86
0
            // `None` entries (nulls) are skipped, so they never enter the set.
            if let Some(value) = value {
87
0
                self.values.insert(value);
88
0
            }
89
0
        });
90
0
91
0
        Ok(())
92
0
    }
93
94
0
    /// Merges list-encoded states (the shape emitted by `state`) into the
    /// distinct set.
    ///
    /// Returns `Ok(())` immediately when `states` is empty. Every non-null
    /// entry of the list array in `states[0]` is read as a primitive array
    /// and its values are added to the set. Errors from `as_list_array` or
    /// `as_primitive_array` are propagated.
    ///
    /// # Panics
    ///
    /// Panics if `states` is non-empty and does not hold exactly one array.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> {
95
0
        if states.is_empty() {
96
0
            return Ok(());
97
0
        }
98
0
        assert_eq!(
99
0
            states.len(),
100
            1,
101
0
            "count_distinct states must be single array"
102
        );
103
104
0
        let arr = as_list_array(&states[0])?;
105
0
        arr.iter().try_for_each(|maybe_list| {
106
0
            if let Some(list) = maybe_list {
107
0
                let list = as_primitive_array::<T>(&list)?;
108
0
                // NOTE(review): `values()` appears to expose the underlying
                // value buffer without consulting validity; that is fine as
                // long as the inner arrays never contain nulls -- confirm.
                self.values.extend(list.values())
109
0
            };
110
0
            Ok(())
111
0
        })
112
0
    }
113
114
0
    /// Returns the number of distinct values collected so far as a
    /// `ScalarValue::Int64`.
    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
115
0
        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
116
0
    }
117
118
0
    /// Estimates the memory used by this accumulator, in bytes.
    ///
    /// The fixed part is `size_of_val(self)` plus `size_of_val(&self.values)`;
    /// it is handed to `estimate_memory_size` together with the current
    /// number of elements.
    ///
    /// # Panics
    ///
    /// Panics if `estimate_memory_size` does not yield a value (the result is
    /// unwrapped).
    fn size(&self) -> usize {
119
0
        let num_elements = self.values.len();
120
0
        let fixed_size =
121
0
            std::mem::size_of_val(self) + std::mem::size_of_val(&self.values);
122
0
123
0
        estimate_memory_size::<T::Native>(num_elements, fixed_size).unwrap()
124
0
    }
125
}
126
127
/// Accumulator that collects the distinct native values of an Arrow primitive
/// type `T`, storing each value wrapped in `Hashable`.
///
/// Unlike the bounds on `PrimitiveDistinctCountAccumulator`, `T::Native` is
/// not required to be `Eq + Hash` here.
// NOTE(review): presumably `Hashable` supplies `Eq`/`Hash` for float natives --
// confirm against `crate::utils::Hashable`.
#[derive(Debug)]
128
pub struct FloatDistinctCountAccumulator<T>
129
where
130
    T: ArrowPrimitiveType + Send,
131
{
132
    /// Distinct values collected so far, each wrapped in `Hashable`.
    values: HashSet<Hashable<T::Native>, RandomState>,
133
}
134
135
impl<T> FloatDistinctCountAccumulator<T>
136
where
137
    T: ArrowPrimitiveType + Send,
138
{
139
0
    /// Creates an accumulator with an empty value set.
    pub fn new() -> Self {
140
0
        Self {
141
0
            values: HashSet::default(),
142
0
        }
143
0
    }
144
}
145
146
impl<T> Default for FloatDistinctCountAccumulator<T>
147
where
148
    T: ArrowPrimitiveType + Send,
149
{
150
0
    /// Returns an empty accumulator; delegates to `Self::new`.
    fn default() -> Self {
151
0
        Self::new()
152
0
    }
153
}
154
155
impl<T> Accumulator for FloatDistinctCountAccumulator<T>
156
where
157
    T: ArrowPrimitiveType + Send + Debug,
158
{
159
0
    /// Serializes the collected distinct values as this accumulator's state.
    ///
    /// Each stored value is unwrapped from `Hashable` and copied into a
    /// `PrimitiveArray<T>`, which is wrapped by
    /// `array_into_list_array_nullable` and returned as a one-element vector
    /// holding a `ScalarValue::List`.
    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
160
0
        let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
161
0
            // `.0` extracts the native value from the `Hashable` wrapper;
            // elements follow the hash set's iteration order.
            self.values.iter().map(|v| v.0),
162
0
        )) as ArrayRef;
163
0
        let list = Arc::new(array_into_list_array_nullable(arr));
164
0
        Ok(vec![ScalarValue::List(list)])
165
0
    }
166
167
0
    /// Adds every non-null value of `values[0]` to the distinct set.
    ///
    /// Returns `Ok(())` immediately when `values` is empty; arrays after the
    /// first are not read. Any error returned by `as_primitive_array::<T>`
    /// for the first array is propagated.
    fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
168
0
        if values.is_empty() {
169
0
            return Ok(());
170
0
        }
171
172
0
        let arr = as_primitive_array::<T>(&values[0])?;
173
0
        arr.iter().for_each(|value| {
174
0
            // `None` entries (nulls) are skipped, so they never enter the set.
            if let Some(value) = value {
175
0
                self.values.insert(Hashable(value));
176
0
            }
177
0
        });
178
0
179
0
        Ok(())
180
0
    }
181
182
0
    /// Merges list-encoded states (the shape emitted by `state`) into the
    /// distinct set.
    ///
    /// Returns `Ok(())` immediately when `states` is empty. Every non-null
    /// entry of the list array in `states[0]` is read as a primitive array
    /// and each of its values is wrapped in `Hashable` and added to the set.
    /// Errors from `as_list_array` or `as_primitive_array` are propagated.
    ///
    /// # Panics
    ///
    /// Panics if `states` is non-empty and does not hold exactly one array.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> {
183
0
        if states.is_empty() {
184
0
            return Ok(());
185
0
        }
186
0
        assert_eq!(
187
0
            states.len(),
188
            1,
189
0
            "count_distinct states must be single array"
190
        );
191
192
0
        let arr = as_list_array(&states[0])?;
193
0
        arr.iter().try_for_each(|maybe_list| {
194
0
            if let Some(list) = maybe_list {
195
0
                let list = as_primitive_array::<T>(&list)?;
196
0
                // NOTE(review): `values()` appears to expose the underlying
                // value buffer without consulting validity; that is fine as
                // long as the inner arrays never contain nulls -- confirm.
                self.values
197
0
                    .extend(list.values().iter().map(|v| Hashable(*v)));
198
0
            };
199
0
            Ok(())
200
0
        })
201
0
    }
202
203
0
    /// Returns the number of distinct values collected so far as a
    /// `ScalarValue::Int64`.
    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
204
0
        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
205
0
    }
206
207
0
    /// Estimates the memory used by this accumulator, in bytes.
    ///
    /// The fixed part is `size_of_val(self)` plus `size_of_val(&self.values)`;
    /// it is handed to `estimate_memory_size` together with the current
    /// number of elements. The per-element type passed is `T::Native`, not
    /// the `Hashable` wrapper.
    ///
    /// # Panics
    ///
    /// Panics if `estimate_memory_size` does not yield a value (the result is
    /// unwrapped).
    fn size(&self) -> usize {
208
0
        let num_elements = self.values.len();
209
0
        let fixed_size =
210
0
            std::mem::size_of_val(self) + std::mem::size_of_val(&self.values);
211
0
212
0
        estimate_memory_size::<T::Native>(num_elements, fixed_size).unwrap()
213
0
    }
214
}