/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate-common/src/utils.rs
Line | Count | Source (jump to first uncovered line) |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use std::sync::Arc; |
19 | | |
20 | | use arrow::array::{ArrayRef, AsArray}; |
21 | | use arrow::datatypes::ArrowNativeType; |
22 | | use arrow::{ |
23 | | array::ArrowNativeTypeOp, |
24 | | compute::SortOptions, |
25 | | datatypes::{ |
26 | | DataType, Decimal128Type, DecimalType, Field, TimeUnit, TimestampMicrosecondType, |
27 | | TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, |
28 | | ToByteSlice, |
29 | | }, |
30 | | }; |
31 | | use datafusion_common::{exec_err, DataFusionError, Result}; |
32 | | use datafusion_expr_common::accumulator::Accumulator; |
33 | | use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; |
34 | | |
35 | | /// Convert scalar values from an accumulator into arrays. |
36 | 0 | pub fn get_accum_scalar_values_as_arrays( |
37 | 0 | accum: &mut dyn Accumulator, |
38 | 0 | ) -> Result<Vec<ArrayRef>> { |
39 | 0 | accum |
40 | 0 | .state()? |
41 | 0 | .iter() |
42 | 0 | .map(|s| s.to_array_of_size(1)) |
43 | 0 | .collect() |
44 | 0 | } |
45 | | |
46 | | /// Adjust array type metadata if needed |
47 | | /// |
48 | | /// Since `Decimal128Arrays` created from `Vec<NativeType>` have |
49 | | /// default precision and scale, this function adjusts the output to |
50 | | /// match `data_type`, if necessary |
51 | 10 | pub fn adjust_output_array(data_type: &DataType, array: ArrayRef) -> Result<ArrayRef> { |
52 | 10 | let array = match data_type0 { |
53 | 0 | DataType::Decimal128(p, s) => Arc::new( |
54 | 0 | array |
55 | 0 | .as_primitive::<Decimal128Type>() |
56 | 0 | .clone() |
57 | 0 | .with_precision_and_scale(*p, *s)?, |
58 | | ) as ArrayRef, |
59 | 0 | DataType::Timestamp(TimeUnit::Nanosecond, tz) => Arc::new( |
60 | 0 | array |
61 | 0 | .as_primitive::<TimestampNanosecondType>() |
62 | 0 | .clone() |
63 | 0 | .with_timezone_opt(tz.clone()), |
64 | 0 | ), |
65 | 0 | DataType::Timestamp(TimeUnit::Microsecond, tz) => Arc::new( |
66 | 0 | array |
67 | 0 | .as_primitive::<TimestampMicrosecondType>() |
68 | 0 | .clone() |
69 | 0 | .with_timezone_opt(tz.clone()), |
70 | 0 | ), |
71 | 0 | DataType::Timestamp(TimeUnit::Millisecond, tz) => Arc::new( |
72 | 0 | array |
73 | 0 | .as_primitive::<TimestampMillisecondType>() |
74 | 0 | .clone() |
75 | 0 | .with_timezone_opt(tz.clone()), |
76 | 0 | ), |
77 | 0 | DataType::Timestamp(TimeUnit::Second, tz) => Arc::new( |
78 | 0 | array |
79 | 0 | .as_primitive::<TimestampSecondType>() |
80 | 0 | .clone() |
81 | 0 | .with_timezone_opt(tz.clone()), |
82 | 0 | ), |
83 | | // no adjustment needed for other arrays |
84 | 10 | _ => array, |
85 | | }; |
86 | 10 | Ok(array) |
87 | 10 | } |
88 | | |
89 | | /// Construct corresponding fields for lexicographical ordering requirement expression |
90 | 16 | pub fn ordering_fields( |
91 | 16 | ordering_req: &[PhysicalSortExpr], |
92 | 16 | // Data type of each expression in the ordering requirement |
93 | 16 | data_types: &[DataType], |
94 | 16 | ) -> Vec<Field> { |
95 | 16 | ordering_req |
96 | 16 | .iter() |
97 | 16 | .zip(data_types.iter()) |
98 | 22 | .map(|(sort_expr, dtype)| { |
99 | 22 | Field::new( |
100 | 22 | sort_expr.expr.to_string().as_str(), |
101 | 22 | dtype.clone(), |
102 | 22 | // Multi partitions may be empty hence field should be nullable. |
103 | 22 | true, |
104 | 22 | ) |
105 | 22 | }) |
106 | 16 | .collect() |
107 | 16 | } |
108 | | |
109 | | /// Selects the sort option attribute from all the given `PhysicalSortExpr`s. |
110 | 66 | pub fn get_sort_options(ordering_req: &[PhysicalSortExpr]) -> Vec<SortOptions> { |
111 | 66 | ordering_req.iter().map(|item| item.options).collect() |
112 | 66 | } |
113 | | |
/// A wrapper around a type to provide hash for floats
///
/// Newtype wrapper; the inner value is public so callers can unwrap it
/// directly via `.0`.
#[derive(Copy, Clone, Debug)]
pub struct Hashable<T>(pub T);
117 | | |
118 | | impl<T: ToByteSlice> std::hash::Hash for Hashable<T> { |
119 | 0 | fn hash<H: std::hash::Hasher>(&self, state: &mut H) { |
120 | 0 | self.0.to_byte_slice().hash(state) |
121 | 0 | } |
122 | | } |
123 | | |
124 | | impl<T: ArrowNativeTypeOp> PartialEq for Hashable<T> { |
125 | 0 | fn eq(&self, other: &Self) -> bool { |
126 | 0 | self.0.is_eq(other.0) |
127 | 0 | } |
128 | | } |
129 | | |
130 | | impl<T: ArrowNativeTypeOp> Eq for Hashable<T> {} |
131 | | |
132 | | /// Computes averages for `Decimal128`/`Decimal256` values, checking for overflow |
133 | | /// |
134 | | /// This is needed because different precisions for Decimal128/Decimal256 can |
135 | | /// store different ranges of values and thus sum/count may not fit in |
136 | | /// the target type. |
137 | | /// |
138 | | /// For example, the precision is 3, the max of value is `999` and the min |
139 | | /// value is `-999` |
140 | | pub struct DecimalAverager<T: DecimalType> { |
141 | | /// scale factor for sum values (10^sum_scale) |
142 | | sum_mul: T::Native, |
143 | | /// scale factor for target (10^target_scale) |
144 | | target_mul: T::Native, |
145 | | /// the output precision |
146 | | target_precision: u8, |
147 | | } |
148 | | |
149 | | impl<T: DecimalType> DecimalAverager<T> { |
150 | | /// Create a new `DecimalAverager`: |
151 | | /// |
152 | | /// * sum_scale: the scale of `sum` values passed to [`Self::avg`] |
153 | | /// * target_precision: the output precision |
154 | | /// * target_scale: the output scale |
155 | | /// |
156 | | /// Errors if the resulting data can not be stored |
157 | 0 | pub fn try_new( |
158 | 0 | sum_scale: i8, |
159 | 0 | target_precision: u8, |
160 | 0 | target_scale: i8, |
161 | 0 | ) -> Result<Self> { |
162 | 0 | let sum_mul = T::Native::from_usize(10_usize) |
163 | 0 | .map(|b| b.pow_wrapping(sum_scale as u32)) |
164 | 0 | .ok_or(DataFusionError::Internal( |
165 | 0 | "Failed to compute sum_mul in DecimalAverager".to_string(), |
166 | 0 | ))?; |
167 | | |
168 | 0 | let target_mul = T::Native::from_usize(10_usize) |
169 | 0 | .map(|b| b.pow_wrapping(target_scale as u32)) |
170 | 0 | .ok_or(DataFusionError::Internal( |
171 | 0 | "Failed to compute target_mul in DecimalAverager".to_string(), |
172 | 0 | ))?; |
173 | | |
174 | 0 | if target_mul >= sum_mul { |
175 | 0 | Ok(Self { |
176 | 0 | sum_mul, |
177 | 0 | target_mul, |
178 | 0 | target_precision, |
179 | 0 | }) |
180 | | } else { |
181 | | // can't convert the lit decimal to the returned data type |
182 | 0 | exec_err!("Arithmetic Overflow in AvgAccumulator") |
183 | | } |
184 | 0 | } |
185 | | |
186 | | /// Returns the `sum`/`count` as a i128/i256 Decimal128/Decimal256 with |
187 | | /// target_scale and target_precision and reporting overflow. |
188 | | /// |
189 | | /// * sum: The total sum value stored as Decimal128 with sum_scale |
190 | | /// (passed to `Self::try_new`) |
191 | | /// * count: total count, stored as a i128/i256 (*NOT* a Decimal128/Decimal256 value) |
192 | | #[inline(always)] |
193 | 0 | pub fn avg(&self, sum: T::Native, count: T::Native) -> Result<T::Native> { |
194 | 0 | if let Ok(value) = sum.mul_checked(self.target_mul.div_wrapping(self.sum_mul)) { |
195 | 0 | let new_value = value.div_wrapping(count); |
196 | 0 |
|
197 | 0 | let validate = |
198 | 0 | T::validate_decimal_precision(new_value, self.target_precision); |
199 | 0 |
|
200 | 0 | if validate.is_ok() { |
201 | 0 | Ok(new_value) |
202 | | } else { |
203 | 0 | exec_err!("Arithmetic Overflow in AvgAccumulator") |
204 | | } |
205 | | } else { |
206 | | // can't convert the lit decimal to the returned data type |
207 | 0 | exec_err!("Arithmetic Overflow in AvgAccumulator") |
208 | | } |
209 | 0 | } |
210 | | } |