/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate/src/correlation.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! [`Correlation`]: correlation sample aggregations. |
19 | | |
20 | | use std::any::Any; |
21 | | use std::fmt::Debug; |
22 | | use std::sync::Arc; |
23 | | |
24 | | use arrow::compute::{and, filter, is_not_null}; |
25 | | use arrow::{ |
26 | | array::ArrayRef, |
27 | | datatypes::{DataType, Field}, |
28 | | }; |
29 | | |
30 | | use crate::covariance::CovarianceAccumulator; |
31 | | use crate::stddev::StddevAccumulator; |
32 | | use datafusion_common::{plan_err, Result, ScalarValue}; |
33 | | use datafusion_expr::{ |
34 | | function::{AccumulatorArgs, StateFieldsArgs}, |
35 | | type_coercion::aggregates::NUMERICS, |
36 | | utils::format_state_name, |
37 | | Accumulator, AggregateUDFImpl, Signature, Volatility, |
38 | | }; |
39 | | use datafusion_functions_aggregate_common::stats::StatsType; |
40 | | |
41 | | make_udaf_expr_and_func!( |
42 | | Correlation, |
43 | | corr, |
44 | | y x, |
45 | | "Correlation between two numeric values.", |
46 | | corr_udaf |
47 | | ); |
48 | | |
49 | | #[derive(Debug)] |
50 | | pub struct Correlation { |
51 | | signature: Signature, |
52 | | } |
53 | | |
54 | | impl Default for Correlation { |
55 | 0 | fn default() -> Self { |
56 | 0 | Self::new() |
57 | 0 | } |
58 | | } |
59 | | |
60 | | impl Correlation { |
61 | | /// Create a new COVAR_POP aggregate function |
62 | 0 | pub fn new() -> Self { |
63 | 0 | Self { |
64 | 0 | signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), |
65 | 0 | } |
66 | 0 | } |
67 | | } |
68 | | |
69 | | impl AggregateUDFImpl for Correlation { |
70 | | /// Return a reference to Any that can be used for downcasting |
71 | 0 | fn as_any(&self) -> &dyn Any { |
72 | 0 | self |
73 | 0 | } |
74 | | |
75 | 0 | fn name(&self) -> &str { |
76 | 0 | "corr" |
77 | 0 | } |
78 | | |
79 | 0 | fn signature(&self) -> &Signature { |
80 | 0 | &self.signature |
81 | 0 | } |
82 | | |
83 | 0 | fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { |
84 | 0 | if !arg_types[0].is_numeric() { |
85 | 0 | return plan_err!("Correlation requires numeric input types"); |
86 | 0 | } |
87 | 0 |
|
88 | 0 | Ok(DataType::Float64) |
89 | 0 | } |
90 | | |
91 | 0 | fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { |
92 | 0 | Ok(Box::new(CorrelationAccumulator::try_new()?)) |
93 | 0 | } |
94 | | |
95 | 0 | fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { |
96 | 0 | let name = args.name; |
97 | 0 | Ok(vec![ |
98 | 0 | Field::new(format_state_name(name, "count"), DataType::UInt64, true), |
99 | 0 | Field::new(format_state_name(name, "mean1"), DataType::Float64, true), |
100 | 0 | Field::new(format_state_name(name, "m2_1"), DataType::Float64, true), |
101 | 0 | Field::new(format_state_name(name, "mean2"), DataType::Float64, true), |
102 | 0 | Field::new(format_state_name(name, "m2_2"), DataType::Float64, true), |
103 | 0 | Field::new( |
104 | 0 | format_state_name(name, "algo_const"), |
105 | 0 | DataType::Float64, |
106 | 0 | true, |
107 | 0 | ), |
108 | 0 | ]) |
109 | 0 | } |
110 | | } |
111 | | |
112 | | /// An accumulator to compute correlation |
113 | | #[derive(Debug)] |
114 | | pub struct CorrelationAccumulator { |
115 | | covar: CovarianceAccumulator, |
116 | | stddev1: StddevAccumulator, |
117 | | stddev2: StddevAccumulator, |
118 | | } |
119 | | |
120 | | impl CorrelationAccumulator { |
121 | | /// Creates a new `CorrelationAccumulator` |
122 | 0 | pub fn try_new() -> Result<Self> { |
123 | 0 | Ok(Self { |
124 | 0 | covar: CovarianceAccumulator::try_new(StatsType::Population)?, |
125 | 0 | stddev1: StddevAccumulator::try_new(StatsType::Population)?, |
126 | 0 | stddev2: StddevAccumulator::try_new(StatsType::Population)?, |
127 | | }) |
128 | 0 | } |
129 | | } |
130 | | |
131 | | impl Accumulator for CorrelationAccumulator { |
132 | 0 | fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
133 | | // TODO: null input skipping logic duplicated across Correlation |
134 | | // and its children accumulators. |
135 | | // This could be simplified by splitting up input filtering and |
136 | | // calculation logic in children accumulators, and calling only |
137 | | // calculation part from Correlation |
138 | 0 | let values = if values[0].null_count() != 0 || values[1].null_count() != 0 { |
139 | 0 | let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?; |
140 | 0 | let values1 = filter(&values[0], &mask)?; |
141 | 0 | let values2 = filter(&values[1], &mask)?; |
142 | | |
143 | 0 | vec![values1, values2] |
144 | | } else { |
145 | 0 | values.to_vec() |
146 | | }; |
147 | | |
148 | 0 | self.covar.update_batch(&values)?; |
149 | 0 | self.stddev1.update_batch(&values[0..1])?; |
150 | 0 | self.stddev2.update_batch(&values[1..2])?; |
151 | 0 | Ok(()) |
152 | 0 | } |
153 | | |
154 | 0 | fn evaluate(&mut self) -> Result<ScalarValue> { |
155 | 0 | let covar = self.covar.evaluate()?; |
156 | 0 | let stddev1 = self.stddev1.evaluate()?; |
157 | 0 | let stddev2 = self.stddev2.evaluate()?; |
158 | | |
159 | 0 | if let ScalarValue::Float64(Some(c)) = covar { |
160 | 0 | if let ScalarValue::Float64(Some(s1)) = stddev1 { |
161 | 0 | if let ScalarValue::Float64(Some(s2)) = stddev2 { |
162 | 0 | if s1 == 0_f64 || s2 == 0_f64 { |
163 | 0 | return Ok(ScalarValue::Float64(Some(0_f64))); |
164 | | } else { |
165 | 0 | return Ok(ScalarValue::Float64(Some(c / s1 / s2))); |
166 | | } |
167 | 0 | } |
168 | 0 | } |
169 | 0 | } |
170 | | |
171 | 0 | Ok(ScalarValue::Float64(None)) |
172 | 0 | } |
173 | | |
174 | 0 | fn size(&self) -> usize { |
175 | 0 | std::mem::size_of_val(self) - std::mem::size_of_val(&self.covar) |
176 | 0 | + self.covar.size() |
177 | 0 | - std::mem::size_of_val(&self.stddev1) |
178 | 0 | + self.stddev1.size() |
179 | 0 | - std::mem::size_of_val(&self.stddev2) |
180 | 0 | + self.stddev2.size() |
181 | 0 | } |
182 | | |
183 | 0 | fn state(&mut self) -> Result<Vec<ScalarValue>> { |
184 | 0 | Ok(vec![ |
185 | 0 | ScalarValue::from(self.covar.get_count()), |
186 | 0 | ScalarValue::from(self.covar.get_mean1()), |
187 | 0 | ScalarValue::from(self.stddev1.get_m2()), |
188 | 0 | ScalarValue::from(self.covar.get_mean2()), |
189 | 0 | ScalarValue::from(self.stddev2.get_m2()), |
190 | 0 | ScalarValue::from(self.covar.get_algo_const()), |
191 | 0 | ]) |
192 | 0 | } |
193 | | |
194 | 0 | fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { |
195 | 0 | let states_c = [ |
196 | 0 | Arc::clone(&states[0]), |
197 | 0 | Arc::clone(&states[1]), |
198 | 0 | Arc::clone(&states[3]), |
199 | 0 | Arc::clone(&states[5]), |
200 | 0 | ]; |
201 | 0 | let states_s1 = [ |
202 | 0 | Arc::clone(&states[0]), |
203 | 0 | Arc::clone(&states[1]), |
204 | 0 | Arc::clone(&states[2]), |
205 | 0 | ]; |
206 | 0 | let states_s2 = [ |
207 | 0 | Arc::clone(&states[0]), |
208 | 0 | Arc::clone(&states[3]), |
209 | 0 | Arc::clone(&states[4]), |
210 | 0 | ]; |
211 | 0 |
|
212 | 0 | self.covar.merge_batch(&states_c)?; |
213 | 0 | self.stddev1.merge_batch(&states_s1)?; |
214 | 0 | self.stddev2.merge_batch(&states_s2)?; |
215 | 0 | Ok(()) |
216 | 0 | } |
217 | | |
218 | 0 | fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { |
219 | 0 | let values = if values[0].null_count() != 0 || values[1].null_count() != 0 { |
220 | 0 | let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?; |
221 | 0 | let values1 = filter(&values[0], &mask)?; |
222 | 0 | let values2 = filter(&values[1], &mask)?; |
223 | | |
224 | 0 | vec![values1, values2] |
225 | | } else { |
226 | 0 | values.to_vec() |
227 | | }; |
228 | | |
229 | 0 | self.covar.retract_batch(&values)?; |
230 | 0 | self.stddev1.retract_batch(&values[0..1])?; |
231 | 0 | self.stddev2.retract_batch(&values[1..2])?; |
232 | 0 | Ok(()) |
233 | 0 | } |
234 | | } |