Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/physical-expr/src/expressions/negative.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Negation (-) expression
19
20
use std::any::Any;
21
use std::hash::{Hash, Hasher};
22
use std::sync::Arc;
23
24
use crate::physical_expr::down_cast_any_ref;
25
use crate::PhysicalExpr;
26
27
use arrow::{
28
    compute::kernels::numeric::neg_wrapping,
29
    datatypes::{DataType, Schema},
30
    record_batch::RecordBatch,
31
};
32
use datafusion_common::{plan_err, Result};
33
use datafusion_expr::interval_arithmetic::Interval;
34
use datafusion_expr::sort_properties::ExprProperties;
35
use datafusion_expr::{
36
    type_coercion::{is_interval, is_null, is_signed_numeric, is_timestamp},
37
    ColumnarValue,
38
};
39
40
/// Negative expression
41
#[derive(Debug, Hash)]
42
pub struct NegativeExpr {
43
    /// Input expression
44
    arg: Arc<dyn PhysicalExpr>,
45
}
46
47
impl NegativeExpr {
48
    /// Create new not expression
49
0
    pub fn new(arg: Arc<dyn PhysicalExpr>) -> Self {
50
0
        Self { arg }
51
0
    }
52
53
    /// Get the input expression
54
0
    pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
55
0
        &self.arg
56
0
    }
57
}
58
59
impl std::fmt::Display for NegativeExpr {
60
0
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
61
0
        write!(f, "(- {})", self.arg)
62
0
    }
63
}
64
65
impl PhysicalExpr for NegativeExpr {
66
    /// Return a reference to Any that can be used for downcasting
67
0
    fn as_any(&self) -> &dyn Any {
68
0
        self
69
0
    }
70
71
0
    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
72
0
        self.arg.data_type(input_schema)
73
0
    }
74
75
0
    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
76
0
        self.arg.nullable(input_schema)
77
0
    }
78
79
0
    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
80
0
        let arg = self.arg.evaluate(batch)?;
81
0
        match arg {
82
0
            ColumnarValue::Array(array) => {
83
0
                let result = neg_wrapping(array.as_ref())?;
84
0
                Ok(ColumnarValue::Array(result))
85
            }
86
0
            ColumnarValue::Scalar(scalar) => {
87
0
                Ok(ColumnarValue::Scalar((scalar.arithmetic_negate())?))
88
            }
89
        }
90
0
    }
91
92
0
    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
93
0
        vec![&self.arg]
94
0
    }
95
96
0
    fn with_new_children(
97
0
        self: Arc<Self>,
98
0
        children: Vec<Arc<dyn PhysicalExpr>>,
99
0
    ) -> Result<Arc<dyn PhysicalExpr>> {
100
0
        Ok(Arc::new(NegativeExpr::new(Arc::clone(&children[0]))))
101
0
    }
102
103
0
    fn dyn_hash(&self, state: &mut dyn Hasher) {
104
0
        let mut s = state;
105
0
        self.hash(&mut s);
106
0
    }
107
108
    /// Given the child interval of a NegativeExpr, it calculates the NegativeExpr's interval.
109
    /// It replaces the upper and lower bounds after multiplying them with -1.
110
    /// Ex: `(a, b]` => `[-b, -a)`
111
0
    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
112
0
        Interval::try_new(
113
0
            children[0].upper().arithmetic_negate()?,
114
0
            children[0].lower().arithmetic_negate()?,
115
        )
116
0
    }
117
118
    /// Returns a new [`Interval`] of a NegativeExpr  that has the existing `interval` given that
119
    /// given the input interval is known to be `children`.
120
0
    fn propagate_constraints(
121
0
        &self,
122
0
        interval: &Interval,
123
0
        children: &[&Interval],
124
0
    ) -> Result<Option<Vec<Interval>>> {
125
0
        let child_interval = children[0];
126
0
        let negated_interval = Interval::try_new(
127
0
            interval.upper().arithmetic_negate()?,
128
0
            interval.lower().arithmetic_negate()?,
129
0
        )?;
130
131
0
        Ok(child_interval
132
0
            .intersect(negated_interval)?
133
0
            .map(|result| vec![result]))
134
0
    }
135
136
    /// The ordering of a [`NegativeExpr`] is simply the reverse of its child.
137
0
    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
138
0
        Ok(ExprProperties {
139
0
            sort_properties: -children[0].sort_properties,
140
0
            range: children[0].range.clone().arithmetic_negate()?,
141
        })
142
0
    }
143
}
144
145
impl PartialEq<dyn Any> for NegativeExpr {
146
0
    fn eq(&self, other: &dyn Any) -> bool {
147
0
        down_cast_any_ref(other)
148
0
            .downcast_ref::<Self>()
149
0
            .map(|x| self.arg.eq(&x.arg))
150
0
            .unwrap_or(false)
151
0
    }
152
}
153
154
/// Creates a unary expression NEGATIVE
155
///
156
/// # Errors
157
///
158
/// This function errors when the argument's type is not signed numeric
159
0
pub fn negative(
160
0
    arg: Arc<dyn PhysicalExpr>,
161
0
    input_schema: &Schema,
162
0
) -> Result<Arc<dyn PhysicalExpr>> {
163
0
    let data_type = arg.data_type(input_schema)?;
164
0
    if is_null(&data_type) {
165
0
        Ok(arg)
166
0
    } else if !is_signed_numeric(&data_type)
167
0
        && !is_interval(&data_type)
168
0
        && !is_timestamp(&data_type)
169
    {
170
0
        plan_err!("Negation only supports numeric, interval and timestamp types")
171
    } else {
172
0
        Ok(Arc::new(NegativeExpr::new(arg)))
173
    }
174
0
}
175
176
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::{col, Column};

    use arrow::array::*;
    use arrow::datatypes::*;
    use arrow_schema::DataType::{Float32, Float64, Int16, Int32, Int64, Int8};
    use datafusion_common::cast::as_primitive_array;
    use datafusion_common::DataFusionError;

    use paste::paste;

    // Builds a nullable single-column batch of `$DATA_TY` from the given
    // values plus a trailing null. It negates the column via `negative` and
    // checks that every value was negated and the null was preserved.
    macro_rules! test_array_negative_op {
        ($DATA_TY:tt, $($VALUE:expr),+ $(,)?) => {
            let schema = Schema::new(vec![Field::new("a", DataType::$DATA_TY, true)]);
            let expr = negative(col("a", &schema)?, &schema)?;
            assert_eq!(expr.data_type(&schema)?, DataType::$DATA_TY);
            assert!(expr.nullable(&schema)?);
            let mut arr = Vec::new();
            let mut arr_expected = Vec::new();
            $(
                arr.push(Some($VALUE));
                arr_expected.push(Some(-$VALUE));
            )+
            arr.push(None);
            arr_expected.push(None);
            let input = paste!{[<$DATA_TY Array>]::from(arr)};
            let expected = &paste!{[<$DATA_TY Array>]::from(arr_expected)};
            let batch =
                RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?;
            let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
            // Build the panic message lazily instead of via `expect(format!(..))`.
            let result = as_primitive_array(&result)
                .unwrap_or_else(|_| panic!("failed to downcast to {:?}Array", $DATA_TY));
            assert_eq!(result, expected);
        };
    }

    #[test]
    fn array_negative_op() -> Result<()> {
        test_array_negative_op!(Int8, 2i8, 1i8);
        test_array_negative_op!(Int16, 234i16, 123i16);
        test_array_negative_op!(Int32, 2345i32, 1234i32);
        test_array_negative_op!(Int64, 23456i64, 12345i64);
        test_array_negative_op!(Float32, 2345.0f32, 1234.0f32);
        test_array_negative_op!(Float64, 23456.0f64, 12345.0f64);
        Ok(())
    }

    #[test]
    fn test_evaluate_bounds() -> Result<()> {
        let negative_expr = NegativeExpr {
            arg: Arc::new(Column::new("a", 0)),
        };
        let child_interval = Interval::make(Some(-2), Some(1))?;
        let negative_expr_interval = Interval::make(Some(-1), Some(2))?;
        assert_eq!(
            negative_expr.evaluate_bounds(&[&child_interval])?,
            negative_expr_interval
        );
        Ok(())
    }

    #[test]
    fn test_propagate_constraints() -> Result<()> {
        let negative_expr = NegativeExpr {
            arg: Arc::new(Column::new("a", 0)),
        };
        let original_child_interval = Interval::make(Some(-2), Some(3))?;
        let negative_expr_interval = Interval::make(Some(0), Some(4))?;
        let after_propagation = Some(vec![Interval::make(Some(-2), Some(0))?]);
        assert_eq!(
            negative_expr.propagate_constraints(
                &negative_expr_interval,
                &[&original_child_interval]
            )?,
            after_propagation
        );
        Ok(())
    }

    #[test]
    fn test_negation_valid_types() -> Result<()> {
        let negatable_types = [
            DataType::Int8,
            DataType::Timestamp(TimeUnit::Second, None),
            DataType::Interval(IntervalUnit::YearMonth),
        ];
        for negatable_type in negatable_types {
            let schema = Schema::new(vec![Field::new("a", negatable_type, true)]);
            let _expr = negative(col("a", &schema)?, &schema)?;
        }
        Ok(())
    }

    #[test]
    fn test_negation_invalid_types() -> Result<()> {
        // Strings and unsigned integers are both rejected at planning time.
        for invalid_type in [DataType::Utf8, DataType::UInt8] {
            let schema = Schema::new(vec![Field::new("a", invalid_type, true)]);
            let err = negative(col("a", &schema)?, &schema).unwrap_err();
            // `matches!` alone only yields a bool; it must be asserted.
            assert!(
                matches!(err, DataFusionError::Plan(_)),
                "unexpected error: {err}"
            );
        }
        Ok(())
    }

    #[test]
    fn test_negation_of_null_returns_argument() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Null, true)]);
        let expr = negative(col("a", &schema)?, &schema)?;
        assert!(expr.as_any().downcast_ref::<NegativeExpr>().is_none());
        assert!(expr.as_any().downcast_ref::<Column>().is_some());
        Ok(())
    }
}