Coverage Report

Created: 2024-10-13 08:39

/Users/andrewlamb/Software/datafusion/datafusion/common/src/rounding.rs
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Floating point rounding mode utility library
19
//! TODO: Remove this custom implementation and the "libc" dependency when
20
//!       floating-point rounding mode manipulation functions become available
21
//!       in Rust.
22
23
use std::ops::{Add, BitAnd, Sub};
24
25
use crate::Result;
26
use crate::ScalarValue;
27
28
// Rounding-direction flags for `fesetround`/`fegetround`. Their numeric values
// are ABI-specific and must mirror the platform C library's `<fenv.h>`, so
// every supported architecture declares its own copy.

// aarch64: the flags live in the RMode field (bits 22-23) of FPCR.
#[cfg(all(target_arch = "aarch64", not(target_os = "windows")))]
const FE_UPWARD: i32 = 0x0040_0000;
#[cfg(all(target_arch = "aarch64", not(target_os = "windows")))]
const FE_DOWNWARD: i32 = 0x0080_0000;

// x86_64: the flags live in the rounding-control field (bits 10-11).
#[cfg(all(target_arch = "x86_64", not(target_os = "windows")))]
const FE_UPWARD: i32 = 0x0800;
#[cfg(all(target_arch = "x86_64", not(target_os = "windows")))]
const FE_DOWNWARD: i32 = 0x0400;
39
40
#[cfg(all(
41
    any(target_arch = "x86_64", target_arch = "aarch64"),
42
    not(target_os = "windows")
43
))]
44
extern crate libc;
45
46
#[cfg(all(
    any(target_arch = "x86_64", target_arch = "aarch64"),
    not(target_os = "windows")
))]
extern "C" {
    /// C `fesetround`: installs `round` as the current rounding direction.
    ///
    /// Declared with its real C return type (`int`): zero on success,
    /// non-zero if the requested direction could not be established. The
    /// previous declaration omitted the return value, which mismatched the C
    /// prototype.
    fn fesetround(round: i32) -> i32;
    /// C `fegetround`: returns the current rounding direction.
    fn fegetround() -> i32;
}
54
55
/// Bit-level view of an IEEE-754 floating-point type.
///
/// Exposes conversions between a float and its raw bit pattern, together with
/// the integer constants and special values needed to step from one
/// representable float to its neighbour by adjusting that bit pattern.
pub trait FloatBits {
    /// Unsigned integer of the same width as the float, holding its raw bits.
    type Item: Copy
        + PartialEq
        + BitAnd<Output = Self::Item>
        + Add<Output = Self::Item>
        + Sub<Output = Self::Item>;

    /// Bit pattern of the smallest positive value of this type.
    const TINY_BITS: Self::Item;

    /// Bit pattern of the negative value closest to zero.
    const NEG_TINY_BITS: Self::Item;

    /// Mask that clears the sign bit and keeps every other bit.
    const CLEAR_SIGN_MASK: Self::Item;

    /// The integer one, expressed in `Self::Item`.
    const ONE: Self::Item;

    /// The integer zero, expressed in `Self::Item`.
    const ZERO: Self::Item;

    /// Reinterprets the float as its raw bit pattern.
    fn to_bits(self) -> Self::Item;

    /// Reinterprets a raw bit pattern as a float.
    fn from_bits(bits: Self::Item) -> Self;

    /// Returns `true` if the value is NaN.
    fn float_is_nan(self) -> bool;

    /// Returns positive infinity.
    fn infinity() -> Self;

    /// Returns negative infinity.
    fn neg_infinity() -> Self;
}

/// Implements [`FloatBits`] for a primitive float type by delegating to its
/// inherent methods and associated constants.
macro_rules! impl_float_bits {
    ($float:ident, $bits:ty, $neg_tiny:expr, $clear_sign:expr) => {
        impl FloatBits for $float {
            type Item = $bits;
            // The least positive subnormal always has the bit pattern 1.
            const TINY_BITS: $bits = 1;
            const NEG_TINY_BITS: $bits = $neg_tiny;
            const CLEAR_SIGN_MASK: $bits = $clear_sign;
            const ONE: $bits = 1;
            const ZERO: $bits = 0;

            fn to_bits(self) -> $bits {
                // Type-relative paths resolve to the inherent method first,
                // so this does not recurse into the trait method.
                $float::to_bits(self)
            }

            fn from_bits(bits: $bits) -> Self {
                $float::from_bits(bits)
            }

            fn float_is_nan(self) -> bool {
                $float::is_nan(self)
            }

            fn infinity() -> Self {
                $float::INFINITY
            }

            fn neg_infinity() -> Self {
                $float::NEG_INFINITY
            }
        }
    };
}

impl_float_bits!(f32, u32, 0x8000_0001, 0x7fff_ffff);
impl_float_bits!(
    f64,
    u64,
    0x8000_0000_0000_0001,
    0x7fff_ffff_ffff_ffff
);
154
155
/// Returns the next representable floating-point value greater than the input value.
156
///
157
/// This function takes a floating-point value that implements the FloatBits trait,
158
/// calculates the next representable value greater than the input, and returns it.
159
///
160
/// If the input value is NaN or positive infinity, the function returns the input value.
161
///
162
/// # Examples
163
///
164
/// ```
165
/// use datafusion_common::rounding::next_up;
166
///
167
/// let f: f32 = 1.0;
168
/// let next_f = next_up(f);
169
/// assert_eq!(next_f, 1.0000001);
170
/// ```
171
1.36k
pub fn next_up<F: FloatBits + Copy>(float: F) -> F {
172
1.36k
    let bits = float.to_bits();
173
1.36k
    if float.float_is_nan() || bits == F::infinity().to_bits() {
174
0
        return float;
175
1.36k
    }
176
1.36k
177
1.36k
    let abs = bits & F::CLEAR_SIGN_MASK;
178
1.36k
    let next_bits = if abs == F::ZERO {
179
0
        F::TINY_BITS
180
1.36k
    } else if bits == abs {
181
1.36k
        bits + F::ONE
182
    } else {
183
0
        bits - F::ONE
184
    };
185
1.36k
    F::from_bits(next_bits)
186
1.36k
}
187
188
/// Returns the next representable floating-point value smaller than the input value.
189
///
190
/// This function takes a floating-point value that implements the FloatBits trait,
191
/// calculates the next representable value smaller than the input, and returns it.
192
///
193
/// If the input value is NaN or negative infinity, the function returns the input value.
194
///
195
/// # Examples
196
///
197
/// ```
198
/// use datafusion_common::rounding::next_down;
199
///
200
/// let f: f32 = 1.0;
201
/// let next_f = next_down(f);
202
/// assert_eq!(next_f, 0.99999994);
203
/// ```
204
0
pub fn next_down<F: FloatBits + Copy>(float: F) -> F {
205
0
    let bits = float.to_bits();
206
0
    if float.float_is_nan() || bits == F::neg_infinity().to_bits() {
207
0
        return float;
208
0
    }
209
0
    let abs = bits & F::CLEAR_SIGN_MASK;
210
0
    let next_bits = if abs == F::ZERO {
211
0
        F::NEG_TINY_BITS
212
0
    } else if bits == abs {
213
0
        bits - F::ONE
214
    } else {
215
0
        bits + F::ONE
216
    };
217
0
    F::from_bits(next_bits)
218
0
}
219
220
/// Fallback used where the floating-point rounding mode is not switched.
///
/// Runs `operation` with the default rounding mode. A non-null `Float64` or
/// `Float32` result is then moved one representable value upward
/// (`UPPER == true`, via `next_up`) or downward (`UPPER == false`, via
/// `next_down`). Every other result, including nulls, is returned as produced.
/// Errors from `operation` are propagated unchanged.
#[cfg(any(
    not(any(target_arch = "x86_64", target_arch = "aarch64")),
    target_os = "windows"
))]
fn alter_fp_rounding_mode_conservative<const UPPER: bool, F>(
    lhs: &ScalarValue,
    rhs: &ScalarValue,
    operation: F,
) -> Result<ScalarValue>
where
    F: FnOnce(&ScalarValue, &ScalarValue) -> Result<ScalarValue>,
{
    let widened = match operation(lhs, rhs)? {
        ScalarValue::Float64(Some(v)) => {
            ScalarValue::Float64(Some(if UPPER { next_up(v) } else { next_down(v) }))
        }
        ScalarValue::Float32(Some(v)) => {
            ScalarValue::Float32(Some(if UPPER { next_up(v) } else { next_down(v) }))
        }
        unchanged => unchanged,
    };
    Ok(widened)
}
252
253
16.1k
pub fn alter_fp_rounding_mode<const UPPER: bool, F>(
254
16.1k
    lhs: &ScalarValue,
255
16.1k
    rhs: &ScalarValue,
256
16.1k
    operation: F,
257
16.1k
) -> Result<ScalarValue>
258
16.1k
where
259
16.1k
    F: FnOnce(&ScalarValue, &ScalarValue) -> Result<ScalarValue>,
260
16.1k
{
261
16.1k
    #[cfg(all(
262
16.1k
        any(target_arch = "x86_64", target_arch = "aarch64"),
263
16.1k
        not(target_os = "windows")
264
16.1k
    ))]
265
16.1k
    unsafe {
266
16.1k
        let current = fegetround();
267
16.1k
        fesetround(if UPPER { 
FE_UPWARD0
} else { FE_DOWNWARD });
268
16.1k
        let result = operation(lhs, rhs);
269
16.1k
        fesetround(current);
270
16.1k
        result
271
16.1k
    }
272
16.1k
    #[cfg(any(
273
16.1k
        not(any(target_arch = "x86_64", target_arch = "aarch64")),
274
16.1k
        target_os = "windows"
275
16.1k
    ))]
276
16.1k
    alter_fp_rounding_mode_conservative::<UPPER, _>(lhs, rhs, operation)
277
16.1k
}
278
279
#[cfg(test)]
mod tests {
    use super::{next_down, next_up};

    #[test]
    fn test_next_down() {
        let x = 1.0f64;
        // Clamp value into range [0, 1).
        let clamped = x.clamp(0.0, next_down(1.0f64));
        assert!(clamped < 1.0);
        assert_eq!(next_up(clamped), 1.0);
    }

    #[test]
    fn test_next_up_small_positive() {
        let value: f64 = 1.0;
        let result = next_up(value);
        assert_eq!(result, 1.0000000000000002);
    }

    #[test]
    fn test_next_up_small_negative() {
        let value: f64 = -1.0;
        let result = next_up(value);
        assert_eq!(result, -0.9999999999999999);
    }

    #[test]
    fn test_next_up_pos_infinity() {
        let value: f64 = f64::INFINITY;
        let result = next_up(value);
        assert_eq!(result, f64::INFINITY);
    }

    #[test]
    fn test_next_up_nan() {
        let value: f64 = f64::NAN;
        let result = next_up(value);
        assert!(result.is_nan());
    }

    #[test]
    fn test_next_down_small_positive() {
        let value: f64 = 1.0;
        let result = next_down(value);
        assert_eq!(result, 0.9999999999999999);
    }

    #[test]
    fn test_next_down_small_negative() {
        let value: f64 = -1.0;
        let result = next_down(value);
        assert_eq!(result, -1.0000000000000002);
    }

    #[test]
    fn test_next_down_neg_infinity() {
        let value: f64 = f64::NEG_INFINITY;
        let result = next_down(value);
        assert_eq!(result, f64::NEG_INFINITY);
    }

    #[test]
    fn test_next_down_nan() {
        let value: f64 = f64::NAN;
        let result = next_down(value);
        assert!(result.is_nan());
    }

    #[test]
    fn test_next_up_small_positive_f32() {
        let value: f32 = 1.0;
        let result = next_up(value);
        assert_eq!(result, 1.0000001);
    }

    #[test]
    fn test_next_up_small_negative_f32() {
        let value: f32 = -1.0;
        let result = next_up(value);
        assert_eq!(result, -0.99999994);
    }

    #[test]
    fn test_next_up_pos_infinity_f32() {
        let value: f32 = f32::INFINITY;
        let result = next_up(value);
        assert_eq!(result, f32::INFINITY);
    }

    #[test]
    fn test_next_up_nan_f32() {
        let value: f32 = f32::NAN;
        let result = next_up(value);
        assert!(result.is_nan());
    }

    #[test]
    fn test_next_down_small_positive_f32() {
        let value: f32 = 1.0;
        let result = next_down(value);
        assert_eq!(result, 0.99999994);
    }

    #[test]
    fn test_next_down_small_negative_f32() {
        let value: f32 = -1.0;
        let result = next_down(value);
        assert_eq!(result, -1.0000001);
    }

    #[test]
    fn test_next_down_neg_infinity_f32() {
        let value: f32 = f32::NEG_INFINITY;
        let result = next_down(value);
        assert_eq!(result, f32::NEG_INFINITY);
    }

    #[test]
    fn test_next_down_nan_f32() {
        let value: f32 = f32::NAN;
        let result = next_down(value);
        assert!(result.is_nan());
    }

    // Both signed zeros take the dedicated TINY_BITS branch of `next_up`.
    #[test]
    fn test_next_up_signed_zeros() {
        assert_eq!(next_up(0.0f64).to_bits(), 1);
        assert_eq!(next_up(-0.0f64).to_bits(), 1);
        assert_eq!(next_up(0.0f32).to_bits(), 1);
        assert_eq!(next_up(-0.0f32).to_bits(), 1);
    }

    // Both signed zeros take the dedicated NEG_TINY_BITS branch of `next_down`.
    #[test]
    fn test_next_down_signed_zeros() {
        assert_eq!(next_down(0.0f64).to_bits(), 0x8000_0000_0000_0001);
        assert_eq!(next_down(-0.0f64).to_bits(), 0x8000_0000_0000_0001);
        assert_eq!(next_down(0.0f32).to_bits(), 0x8000_0001);
        assert_eq!(next_down(-0.0f32).to_bits(), 0x8000_0001);
    }

    // Stepping towards zero from the smallest subnormals lands on a zero
    // carrying the original sign.
    #[test]
    fn test_step_to_zero_from_tiny() {
        assert_eq!(next_down(f64::from_bits(1)).to_bits(), 0);
        assert_eq!(
            next_up(f64::from_bits(0x8000_0000_0000_0001)).to_bits(),
            0x8000_0000_0000_0000
        );
    }

    #[test]
    fn test_next_up_overflows_to_infinity() {
        assert_eq!(next_up(f64::MAX), f64::INFINITY);
        assert_eq!(next_up(f32::MAX), f32::INFINITY);
    }

    #[test]
    fn test_next_down_overflows_to_neg_infinity() {
        assert_eq!(next_down(f64::MIN), f64::NEG_INFINITY);
        assert_eq!(next_down(f32::MIN), f32::NEG_INFINITY);
    }

    #[test]
    fn test_next_up_neg_infinity() {
        assert_eq!(next_up(f64::NEG_INFINITY), f64::MIN);
        assert_eq!(next_up(f32::NEG_INFINITY), f32::MIN);
    }

    #[test]
    fn test_next_down_pos_infinity() {
        assert_eq!(next_down(f64::INFINITY), f64::MAX);
        assert_eq!(next_down(f32::INFINITY), f32::MAX);
    }

    #[test]
    fn test_next_up_down_round_trip() {
        for v in [1.5f64, -2.25, 1e-300, -1e300, f64::MIN_POSITIVE] {
            assert_eq!(next_down(next_up(v)), v);
            assert_eq!(next_up(next_down(v)), v);
        }
    }
}