From 27b106f832999afa5b3353aaa2adcdb695fb4a47 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Thu, 11 Jan 2024 18:32:19 -0800 Subject: [PATCH] [Java] Choose The Correct RoundingMode For Checking Decimal OutOfBounds (#14731) This PR fixes an error in the `outOfBounds` method in which the `RoundingMode` was selected based on positive values only. The RHS should be rounded towards positive infinity (ROUND_CEILING) for the lower bound and towards negative infinity (ROUND_FLOOR) for the upper bound closes #14732 Authors: - Raza Jafri (https://github.com/razajafri) Approvers: - Jason Lowe (https://github.com/jlowe) - Robert (Bobby) Evans (https://github.com/revans2) URL: https://github.com/rapidsai/cudf/pull/14731 --- .../java/ai/rapids/cudf/DecimalUtils.java | 30 +++++++------- .../java/ai/rapids/cudf/DecimalUtilsTest.java | 40 +++++++++++++++++++ 2 files changed, 55 insertions(+), 15 deletions(-) create mode 100644 java/src/test/java/ai/rapids/cudf/DecimalUtilsTest.java diff --git a/java/src/main/java/ai/rapids/cudf/DecimalUtils.java b/java/src/main/java/ai/rapids/cudf/DecimalUtils.java index 1979bd1bd5b..7a5be9b08b9 100644 --- a/java/src/main/java/ai/rapids/cudf/DecimalUtils.java +++ b/java/src/main/java/ai/rapids/cudf/DecimalUtils.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -82,13 +82,13 @@ public static ColumnVector lessThan(ColumnView lhs, BigDecimal rhs) { int leftScale = lhs.getType().getScale(); int leftPrecision = lhs.getType().getDecimalMaxPrecision(); - // First we have to round the scalar (rhs) to the same scale as lhs. Because this is a - // less than and it is rhs that we are rounding, we will round away from 0 (UP) - // to make sure we always return the correct value. - // For example: - // 100.1 < 100.19 - // If we rounded down the rhs 100.19 would become 100.1, and now 100.1 is not < 100.1 - BigDecimal roundedRhs = rhs.setScale(-leftScale, BigDecimal.ROUND_UP); + // First we have to round the scalar (rhs) to the same scale as lhs. + // For comparing the two values they should be the same scale, we round the value to positive infinity to maintain + // the relation. Ex: + // 10.2 < 10.29 = true, after rounding rhs to ceiling ===> 10.2 < 10.3 = true, relation is maintained + // 10.3 < 10.29 = false, after rounding rhs to ceiling ===> 10.3 < 10.3 = false, relation is maintained + // 10.1 < 10.10 = false, after rounding rhs to ceiling ===> 10.1 < 10.1 = false, relation is maintained + BigDecimal roundedRhs = rhs.setScale(-leftScale, BigDecimal.ROUND_CEILING); if (roundedRhs.precision() > leftPrecision) { // converting rhs to the same precision as lhs would result in an overflow/error, but @@ -136,13 +136,13 @@ public static ColumnVector greaterThan(ColumnView lhs, BigDecimal rhs) { int cvScale = lhs.getType().getScale(); int maxPrecision = lhs.getType().getDecimalMaxPrecision(); - // First we have to round the scalar (rhs) to the same scale as lhs. Because this is a - // greater than and it is rhs that we are rounding, we will round towards 0 (DOWN) - // to make sure we always return the correct value. - // For example: - // 100.2 > 100.19 - // If we rounded up the rhs 100.19 would become 100.2, and now 100.2 is not > 100.2 - BigDecimal roundedRhs = rhs.setScale(-cvScale, BigDecimal.ROUND_DOWN); + // First we have to round the scalar (rhs) to the same scale as lhs. + // For comparing the two values they should be the same scale, we round the value to negative infinity to maintain + // the relation. Ex: + // 10.3 > 10.29 = true, after rounding rhs to floor ===> 10.3 > 10.2 = true, relation is maintained + // 10.2 > 10.29 = false, after rounding rhs to floor ===> 10.2 > 10.2 = false, relation is maintained + // 10.1 > 10.10 = false, after rounding rhs to floor ===> 10.1 > 10.1 = false, relation is maintained + BigDecimal roundedRhs = rhs.setScale(-cvScale, BigDecimal.ROUND_FLOOR); if (roundedRhs.precision() > maxPrecision) { // converting rhs to the same precision as lhs would result in an overflow/error, but diff --git a/java/src/test/java/ai/rapids/cudf/DecimalUtilsTest.java b/java/src/test/java/ai/rapids/cudf/DecimalUtilsTest.java new file mode 100644 index 00000000000..a96eeda5dd7 --- /dev/null +++ b/java/src/test/java/ai/rapids/cudf/DecimalUtilsTest.java @@ -0,0 +1,40 @@ +/* + * + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +import org.junit.jupiter.api.Test; + +import java.math.BigDecimal; +import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; + +public class DecimalUtilsTest extends CudfTestBase { + @Test + public void testOutOfBounds() { + try (ColumnView cv = ColumnVector.fromDecimals( + new BigDecimal("-1E+3"), + new BigDecimal("1E+3"), + new BigDecimal("9E+1"), + new BigDecimal("-9E+1"), + new BigDecimal("-91")); + ColumnView expected = ColumnVector.fromBooleans(true, true, false, false, true); + ColumnView result = DecimalUtils.outOfBounds(cv, 1, -1)) { + assertColumnsAreEqual(expected, result); + } + } +}