Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/correlation.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! [`Correlation`]: the `corr` aggregate, computing the Pearson correlation coefficient of two numeric columns.
19
20
use std::any::Any;
21
use std::fmt::Debug;
22
use std::sync::Arc;
23
24
use arrow::compute::{and, filter, is_not_null};
25
use arrow::{
26
    array::ArrayRef,
27
    datatypes::{DataType, Field},
28
};
29
30
use crate::covariance::CovarianceAccumulator;
31
use crate::stddev::StddevAccumulator;
32
use datafusion_common::{plan_err, Result, ScalarValue};
33
use datafusion_expr::{
34
    function::{AccumulatorArgs, StateFieldsArgs},
35
    type_coercion::aggregates::NUMERICS,
36
    utils::format_state_name,
37
    Accumulator, AggregateUDFImpl, Signature, Volatility,
38
};
39
use datafusion_functions_aggregate_common::stats::StatsType;
40
41
// Wires `Correlation` up through the `make_udaf_expr_and_func!` macro: the
// expression function is named `corr` and takes arguments `y` and `x`, and the
// UDAF accessor is `corr_udaf` (see the macro definition for the exact items
// it emits).
make_udaf_expr_and_func!(
    Correlation, corr, y x,
    "Correlation between two numeric values.",
    corr_udaf
);
48
49
#[derive(Debug)]
50
pub struct Correlation {
51
    signature: Signature,
52
}
53
54
impl Default for Correlation {
55
0
    fn default() -> Self {
56
0
        Self::new()
57
0
    }
58
}
59
60
impl Correlation {
61
    /// Create a new COVAR_POP aggregate function
62
0
    pub fn new() -> Self {
63
0
        Self {
64
0
            signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
65
0
        }
66
0
    }
67
}
68
69
impl AggregateUDFImpl for Correlation {
70
    /// Return a reference to Any that can be used for downcasting
71
0
    fn as_any(&self) -> &dyn Any {
72
0
        self
73
0
    }
74
75
0
    fn name(&self) -> &str {
76
0
        "corr"
77
0
    }
78
79
0
    fn signature(&self) -> &Signature {
80
0
        &self.signature
81
0
    }
82
83
0
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
84
0
        if !arg_types[0].is_numeric() {
85
0
            return plan_err!("Correlation requires numeric input types");
86
0
        }
87
0
88
0
        Ok(DataType::Float64)
89
0
    }
90
91
0
    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
92
0
        Ok(Box::new(CorrelationAccumulator::try_new()?))
93
0
    }
94
95
0
    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
96
0
        let name = args.name;
97
0
        Ok(vec![
98
0
            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
99
0
            Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
100
0
            Field::new(format_state_name(name, "m2_1"), DataType::Float64, true),
101
0
            Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
102
0
            Field::new(format_state_name(name, "m2_2"), DataType::Float64, true),
103
0
            Field::new(
104
0
                format_state_name(name, "algo_const"),
105
0
                DataType::Float64,
106
0
                true,
107
0
            ),
108
0
        ])
109
0
    }
110
}
111
112
/// An accumulator to compute correlation
113
#[derive(Debug)]
114
pub struct CorrelationAccumulator {
115
    covar: CovarianceAccumulator,
116
    stddev1: StddevAccumulator,
117
    stddev2: StddevAccumulator,
118
}
119
120
impl CorrelationAccumulator {
121
    /// Creates a new `CorrelationAccumulator`
122
0
    pub fn try_new() -> Result<Self> {
123
0
        Ok(Self {
124
0
            covar: CovarianceAccumulator::try_new(StatsType::Population)?,
125
0
            stddev1: StddevAccumulator::try_new(StatsType::Population)?,
126
0
            stddev2: StddevAccumulator::try_new(StatsType::Population)?,
127
        })
128
0
    }
129
}
130
131
impl Accumulator for CorrelationAccumulator {
132
0
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
133
        // TODO: null input skipping logic duplicated across Correlation
134
        // and its children accumulators.
135
        // This could be simplified by splitting up input filtering and
136
        // calculation logic in children accumulators, and calling only
137
        // calculation part from Correlation
138
0
        let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
139
0
            let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
140
0
            let values1 = filter(&values[0], &mask)?;
141
0
            let values2 = filter(&values[1], &mask)?;
142
143
0
            vec![values1, values2]
144
        } else {
145
0
            values.to_vec()
146
        };
147
148
0
        self.covar.update_batch(&values)?;
149
0
        self.stddev1.update_batch(&values[0..1])?;
150
0
        self.stddev2.update_batch(&values[1..2])?;
151
0
        Ok(())
152
0
    }
153
154
0
    fn evaluate(&mut self) -> Result<ScalarValue> {
155
0
        let covar = self.covar.evaluate()?;
156
0
        let stddev1 = self.stddev1.evaluate()?;
157
0
        let stddev2 = self.stddev2.evaluate()?;
158
159
0
        if let ScalarValue::Float64(Some(c)) = covar {
160
0
            if let ScalarValue::Float64(Some(s1)) = stddev1 {
161
0
                if let ScalarValue::Float64(Some(s2)) = stddev2 {
162
0
                    if s1 == 0_f64 || s2 == 0_f64 {
163
0
                        return Ok(ScalarValue::Float64(Some(0_f64)));
164
                    } else {
165
0
                        return Ok(ScalarValue::Float64(Some(c / s1 / s2)));
166
                    }
167
0
                }
168
0
            }
169
0
        }
170
171
0
        Ok(ScalarValue::Float64(None))
172
0
    }
173
174
0
    fn size(&self) -> usize {
175
0
        std::mem::size_of_val(self) - std::mem::size_of_val(&self.covar)
176
0
            + self.covar.size()
177
0
            - std::mem::size_of_val(&self.stddev1)
178
0
            + self.stddev1.size()
179
0
            - std::mem::size_of_val(&self.stddev2)
180
0
            + self.stddev2.size()
181
0
    }
182
183
0
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
184
0
        Ok(vec![
185
0
            ScalarValue::from(self.covar.get_count()),
186
0
            ScalarValue::from(self.covar.get_mean1()),
187
0
            ScalarValue::from(self.stddev1.get_m2()),
188
0
            ScalarValue::from(self.covar.get_mean2()),
189
0
            ScalarValue::from(self.stddev2.get_m2()),
190
0
            ScalarValue::from(self.covar.get_algo_const()),
191
0
        ])
192
0
    }
193
194
0
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
195
0
        let states_c = [
196
0
            Arc::clone(&states[0]),
197
0
            Arc::clone(&states[1]),
198
0
            Arc::clone(&states[3]),
199
0
            Arc::clone(&states[5]),
200
0
        ];
201
0
        let states_s1 = [
202
0
            Arc::clone(&states[0]),
203
0
            Arc::clone(&states[1]),
204
0
            Arc::clone(&states[2]),
205
0
        ];
206
0
        let states_s2 = [
207
0
            Arc::clone(&states[0]),
208
0
            Arc::clone(&states[3]),
209
0
            Arc::clone(&states[4]),
210
0
        ];
211
0
212
0
        self.covar.merge_batch(&states_c)?;
213
0
        self.stddev1.merge_batch(&states_s1)?;
214
0
        self.stddev2.merge_batch(&states_s2)?;
215
0
        Ok(())
216
0
    }
217
218
0
    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
219
0
        let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
220
0
            let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
221
0
            let values1 = filter(&values[0], &mask)?;
222
0
            let values2 = filter(&values[1], &mask)?;
223
224
0
            vec![values1, values2]
225
        } else {
226
0
            values.to_vec()
227
        };
228
229
0
        self.covar.retract_batch(&values)?;
230
0
        self.stddev1.retract_batch(&values[0..1])?;
231
0
        self.stddev2.retract_batch(&values[1..2])?;
232
0
        Ok(())
233
0
    }
234
}