Skip to content

Commit

Permalink
Fixed decimal bounds
Browse files Browse the repository at this point in the history
Signed-off-by: Raza Jafri <[email protected]>

Added tests
  • Loading branch information
razajafri committed Jan 10, 2024
1 parent fab5af2 commit 160a3c5
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 7 deletions.
25 changes: 18 additions & 7 deletions java/src/main/java/ai/rapids/cudf/DecimalUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,15 @@ public static ColumnVector lessThan(ColumnView lhs, BigDecimal rhs) {
int leftScale = lhs.getType().getScale();
int leftPrecision = lhs.getType().getDecimalMaxPrecision();

// First we have to round the scalar (rhs) to the same scale as lhs. Because this is a
// less than and it is rhs that we are rounding, we will round away from 0 (UP)
// to make sure we always return the correct value.
// First we have to round the scalar (rhs) to the same scale as lhs.
// Because this is a less-than, and it is rhs that we are rounding, we will round away from 0 (UP) in case rhs is
// positive and toward 0 (DOWN) if rhs is negative to make sure we always return the correct value.
// For example:
// 100.1 < 100.19
// If we rounded down the rhs 100.19 would become 100.1, and now 100.1 is not < 100.1
BigDecimal roundedRhs = rhs.setScale(-leftScale, BigDecimal.ROUND_UP);
// ex:1 100.1 < 100.19
// ex:2 -100.2 < -100.19
// In ex:1 If we rounded down the rhs 100.19 would become 100.1, and now 100.1 is not < 100.1
// In ex:2 If we rounded up the rhs -100.19 would become -100.2, and now -100.2 is not < -100.2
BigDecimal roundedRhs = rhs.setScale(-leftScale, rhs.signum() > 0 ? BigDecimal.ROUND_UP : BigDecimal.ROUND_DOWN);

if (roundedRhs.precision() > leftPrecision) {
// converting rhs to the same precision as lhs would result in an overflow/error, but
Expand Down Expand Up @@ -142,7 +144,16 @@ public static ColumnVector greaterThan(ColumnView lhs, BigDecimal rhs) {
// For example:
// 100.2 > 100.19
// If we rounded up the rhs 100.19 would become 100.2, and now 100.2 is not > 100.2
BigDecimal roundedRhs = rhs.setScale(-cvScale, BigDecimal.ROUND_DOWN);

// First we have to round the scalar (rhs) to the same scale as lhs.
// Because this is a greater-than, and it is rhs that we are rounding, we will round towards 0 (DOWN) in case rhs is
// positive and away from 0 (UP) if rhs is negative to make sure we always return the correct value.
// For example:
// ex:1 100.2 > 100.19
// ex:2 -100.1 > -100.19
// In ex:1 If we rounded up the rhs 100.19 would become 100.2, and now 100.2 is not > 100.2
// In ex:2 If we rounded down the rhs -100.19 would become -100.1, and now -100.1 is not > -100.1
BigDecimal roundedRhs = rhs.setScale(-cvScale, rhs.signum() > 0 ? BigDecimal.ROUND_DOWN : BigDecimal.ROUND_UP);

if (roundedRhs.precision() > maxPrecision) {
// converting rhs to the same precision as lhs would result in an overflow/error, but
Expand Down
40 changes: 40 additions & 0 deletions java/src/test/java/ai/rapids/cudf/DecimalUtilsTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package ai.rapids.cudf;

import org.junit.jupiter.api.Test;

import java.math.BigDecimal;
import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual;

public class DecimalUtilsTest extends CudfTestBase {
@Test
public void testOutOfBounds() {
try (ColumnView cv = ColumnVector.fromDecimals(
new BigDecimal("-1E+3"),
new BigDecimal("1E+3"),
new BigDecimal("9E+1"),
new BigDecimal("-9E+1"),
new BigDecimal("-91"));
ColumnView expected = ColumnVector.fromBooleans(true, true, false, false, true)) {
ColumnView result = DecimalUtils.outOfBounds(cv, 1, -1);
assertColumnsAreEqual(expected, result);
}
}
}

0 comments on commit 160a3c5

Please sign in to comment.