Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/functions-aggregate-common/src/utils.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
use std::sync::Arc;
19
20
use arrow::array::{ArrayRef, AsArray};
21
use arrow::datatypes::ArrowNativeType;
22
use arrow::{
23
    array::ArrowNativeTypeOp,
24
    compute::SortOptions,
25
    datatypes::{
26
        DataType, Decimal128Type, DecimalType, Field, TimeUnit, TimestampMicrosecondType,
27
        TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
28
        ToByteSlice,
29
    },
30
};
31
use datafusion_common::{exec_err, DataFusionError, Result};
32
use datafusion_expr_common::accumulator::Accumulator;
33
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
34
35
/// Convert scalar values from an accumulator into arrays.
36
0
pub fn get_accum_scalar_values_as_arrays(
37
0
    accum: &mut dyn Accumulator,
38
0
) -> Result<Vec<ArrayRef>> {
39
0
    accum
40
0
        .state()?
41
0
        .iter()
42
0
        .map(|s| s.to_array_of_size(1))
43
0
        .collect()
44
0
}
45
46
/// Adjust array type metadata if needed
47
///
48
/// Since `Decimal128Arrays` created from `Vec<NativeType>` have
49
/// default precision and scale, this function adjusts the output to
50
/// match `data_type`, if necessary
51
10
pub fn adjust_output_array(data_type: &DataType, array: ArrayRef) -> Result<ArrayRef> {
52
10
    let array = match 
data_type0
{
53
0
        DataType::Decimal128(p, s) => Arc::new(
54
0
            array
55
0
                .as_primitive::<Decimal128Type>()
56
0
                .clone()
57
0
                .with_precision_and_scale(*p, *s)?,
58
        ) as ArrayRef,
59
0
        DataType::Timestamp(TimeUnit::Nanosecond, tz) => Arc::new(
60
0
            array
61
0
                .as_primitive::<TimestampNanosecondType>()
62
0
                .clone()
63
0
                .with_timezone_opt(tz.clone()),
64
0
        ),
65
0
        DataType::Timestamp(TimeUnit::Microsecond, tz) => Arc::new(
66
0
            array
67
0
                .as_primitive::<TimestampMicrosecondType>()
68
0
                .clone()
69
0
                .with_timezone_opt(tz.clone()),
70
0
        ),
71
0
        DataType::Timestamp(TimeUnit::Millisecond, tz) => Arc::new(
72
0
            array
73
0
                .as_primitive::<TimestampMillisecondType>()
74
0
                .clone()
75
0
                .with_timezone_opt(tz.clone()),
76
0
        ),
77
0
        DataType::Timestamp(TimeUnit::Second, tz) => Arc::new(
78
0
            array
79
0
                .as_primitive::<TimestampSecondType>()
80
0
                .clone()
81
0
                .with_timezone_opt(tz.clone()),
82
0
        ),
83
        // no adjustment needed for other arrays
84
10
        _ => array,
85
    };
86
10
    Ok(array)
87
10
}
88
89
/// Construct corresponding fields for lexicographical ordering requirement expression
90
16
pub fn ordering_fields(
91
16
    ordering_req: &[PhysicalSortExpr],
92
16
    // Data type of each expression in the ordering requirement
93
16
    data_types: &[DataType],
94
16
) -> Vec<Field> {
95
16
    ordering_req
96
16
        .iter()
97
16
        .zip(data_types.iter())
98
22
        .map(|(sort_expr, dtype)| {
99
22
            Field::new(
100
22
                sort_expr.expr.to_string().as_str(),
101
22
                dtype.clone(),
102
22
                // Multi partitions may be empty hence field should be nullable.
103
22
                true,
104
22
            )
105
22
        })
106
16
        .collect()
107
16
}
108
109
/// Selects the sort option attribute from all the given `PhysicalSortExpr`s.
110
66
pub fn get_sort_options(ordering_req: &[PhysicalSortExpr]) -> Vec<SortOptions> {
111
66
    ordering_req.iter().map(|item| item.options).collect()
112
66
}
113
114
/// A wrapper around a type to provide hash for floats
115
#[derive(Copy, Clone, Debug)]
116
pub struct Hashable<T>(pub T);
117
118
impl<T: ToByteSlice> std::hash::Hash for Hashable<T> {
119
0
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
120
0
        self.0.to_byte_slice().hash(state)
121
0
    }
122
}
123
124
impl<T: ArrowNativeTypeOp> PartialEq for Hashable<T> {
125
0
    fn eq(&self, other: &Self) -> bool {
126
0
        self.0.is_eq(other.0)
127
0
    }
128
}
129
130
impl<T: ArrowNativeTypeOp> Eq for Hashable<T> {}
131
132
/// Computes averages for `Decimal128`/`Decimal256` values, checking for overflow
///
/// This is needed because different precisions for Decimal128/Decimal256 can
/// store different ranges of values and thus sum/count may not fit in
/// the target type.
///
/// For example, with a precision of 3 the maximum representable value is
/// `999` and the minimum is `-999`.
pub struct DecimalAverager<T: DecimalType> {
    /// scale factor for sum values (10^sum_scale)
    sum_mul: T::Native,
    /// scale factor for target (10^target_scale)
    target_mul: T::Native,
    /// the output precision (used to validate the computed average fits)
    target_precision: u8,
}
148
149
impl<T: DecimalType> DecimalAverager<T> {
150
    /// Create a new `DecimalAverager`:
151
    ///
152
    /// * sum_scale: the scale of `sum` values passed to [`Self::avg`]
153
    /// * target_precision: the output precision
154
    /// * target_scale: the output scale
155
    ///
156
    /// Errors if the resulting data can not be stored
157
0
    pub fn try_new(
158
0
        sum_scale: i8,
159
0
        target_precision: u8,
160
0
        target_scale: i8,
161
0
    ) -> Result<Self> {
162
0
        let sum_mul = T::Native::from_usize(10_usize)
163
0
            .map(|b| b.pow_wrapping(sum_scale as u32))
164
0
            .ok_or(DataFusionError::Internal(
165
0
                "Failed to compute sum_mul in DecimalAverager".to_string(),
166
0
            ))?;
167
168
0
        let target_mul = T::Native::from_usize(10_usize)
169
0
            .map(|b| b.pow_wrapping(target_scale as u32))
170
0
            .ok_or(DataFusionError::Internal(
171
0
                "Failed to compute target_mul in DecimalAverager".to_string(),
172
0
            ))?;
173
174
0
        if target_mul >= sum_mul {
175
0
            Ok(Self {
176
0
                sum_mul,
177
0
                target_mul,
178
0
                target_precision,
179
0
            })
180
        } else {
181
            // can't convert the lit decimal to the returned data type
182
0
            exec_err!("Arithmetic Overflow in AvgAccumulator")
183
        }
184
0
    }
185
186
    /// Returns the `sum`/`count` as a i128/i256 Decimal128/Decimal256 with
187
    /// target_scale and target_precision and reporting overflow.
188
    ///
189
    /// * sum: The total sum value stored as Decimal128 with sum_scale
190
    ///   (passed to `Self::try_new`)
191
    /// * count: total count, stored as a i128/i256 (*NOT* a Decimal128/Decimal256 value)
192
    #[inline(always)]
193
0
    pub fn avg(&self, sum: T::Native, count: T::Native) -> Result<T::Native> {
194
0
        if let Ok(value) = sum.mul_checked(self.target_mul.div_wrapping(self.sum_mul)) {
195
0
            let new_value = value.div_wrapping(count);
196
0
197
0
            let validate =
198
0
                T::validate_decimal_precision(new_value, self.target_precision);
199
0
200
0
            if validate.is_ok() {
201
0
                Ok(new_value)
202
            } else {
203
0
                exec_err!("Arithmetic Overflow in AvgAccumulator")
204
            }
205
        } else {
206
            // can't convert the lit decimal to the returned data type
207
0
            exec_err!("Arithmetic Overflow in AvgAccumulator")
208
        }
209
0
    }
210
}