JNI: Push back decimal utils from spark-rapids (#9907)
This PR pushes cuDF-related decimal utilities back from spark-rapids (for NVIDIA/spark-rapids#3793).

These utils were manually verified through spark-rapids integration tests.

Authors:
  - Alfred Xu (https://github.com/sperlingxx)

Approvers:
  - Nghia Truong (https://github.com/ttnghia)
  - Jason Lowe (https://github.com/jlowe)

URL: #9907
sperlingxx authored Feb 22, 2022
1 parent 7a17f28 commit 36e8825
Showing 3 changed files with 206 additions and 4 deletions.
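
For orientation before the diffs, here is a hedged end-to-end sketch of how the pushed-back utilities might be driven from the Java side. It is editorial, not part of this commit, and it assumes the existing ColumnVector.decimalFromLongs factory plus made-up literal values.

import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.DecimalUtils;

public class DecimalUtilsSketch {
  public static void main(String[] args) {
    // The cuDF type backing a DECIMAL(18, 2) column: precision 18 fits DECIMAL64, and cuDF
    // stores the two fractional digits as a negative scale (-2).
    DType dec18x2 = DecimalUtils.createDecimalType(18, 2);
    System.out.println("DECIMAL(18, 2) maps to " + dec18x2);

    // Unscaled longs at scale -2, i.e. 12345.67, -99999.99 and 123456.78.
    try (ColumnVector col = ColumnVector.decimalFromLongs(-2, 1234567L, -9999999L, 12345678L);
         ColumnVector overflow = DecimalUtils.outOfBounds(col, 7, 2)) {
      // bounds(7, 2) is (-99999.99, 99999.99), so 'overflow' holds [false, false, true]:
      // only 123456.78 would not fit a DECIMAL(7, 2) column.
    }
  }
}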
31 changes: 31 additions & 0 deletions java/src/main/java/ai/rapids/cudf/DType.java
@@ -173,6 +173,37 @@ private DType(DTypeEnum id, int decimalScale) {
STRUCT
};

/**
* Returns the max precision for this decimal type.
* @return the max precision this decimal type can hold
*/
public int getDecimalMaxPrecision() {
if (!isDecimalType()) {
throw new IllegalArgumentException("not a decimal type: " + this);
}
if (typeId == DTypeEnum.DECIMAL32) return DECIMAL32_MAX_PRECISION;
if (typeId == DTypeEnum.DECIMAL64) return DECIMAL64_MAX_PRECISION;
return DType.DECIMAL128_MAX_PRECISION;
}

/**
* Get the number of decimal digits needed to hold this integral type.
* NOTE: this method is for integral types, NOT for decimal types.
* @return the minimum decimal precision (number of digits) needed to represent this integral type
*/
public int getPrecisionForInt() {
// -128 to 127
if (typeId == DTypeEnum.INT8) return 3;
// -32768 to 32767
if (typeId == DTypeEnum.INT16) return 5;
// -2147483648 to 2147483647
if (typeId == DTypeEnum.INT32) return 10;
// -9223372036854775808 to 9223372036854775807
if (typeId == DTypeEnum.INT64) return 19;

throw new IllegalArgumentException("not an integral type: " + this);
}

/**
* This only works for fixed width types. For variable width types, like strings, the value is
* undefined and should be ignored.
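
Before the new DecimalUtils class below, a short editorial sketch (not part of this diff) of the two helpers added above; the values in the comments follow from cuDF's decimal precision limits of 9, 18 and 38 digits:

// DECIMAL32 with scale -2 can hold at most 9 significant digits.
DType dec32 = DType.create(DType.DTypeEnum.DECIMAL32, -2);
int maxPrecision = dec32.getDecimalMaxPrecision();   // 9

// An INT32 ranges from -2147483648 to 2147483647, so it needs 10 decimal digits,
// meaning it only fits decimal columns of precision >= 10 (DECIMAL64 or DECIMAL128).
int digitsForInt = DType.INT32.getPrecisionForInt();  // 10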
164 changes: 164 additions & 0 deletions java/src/main/java/ai/rapids/cudf/DecimalUtils.java
@@ -0,0 +1,164 @@
/*
*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package ai.rapids.cudf;

import java.math.BigDecimal;
import java.util.AbstractMap;
import java.util.Map;

public class DecimalUtils {

/**
* Creates a cuDF decimal type with the given precision and scale.
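* Note: the returned type carries a negated scale, because cuDF's fixed_point representation
* stores fractional digits as a negative exponent.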
*/
public static DType createDecimalType(int precision, int scale) {
if (precision <= DType.DECIMAL32_MAX_PRECISION) {
return DType.create(DType.DTypeEnum.DECIMAL32, -scale);
} else if (precision <= DType.DECIMAL64_MAX_PRECISION) {
return DType.create(DType.DTypeEnum.DECIMAL64, -scale);
} else if (precision <= DType.DECIMAL128_MAX_PRECISION) {
return DType.create(DType.DTypeEnum.DECIMAL128, -scale);
}
throw new IllegalArgumentException("precision overflow: " + precision);
}

/**
* Given decimal precision and scale, returns the lower and upper bounds of that decimal type.
*
* Be very careful when comparing these: cuDF decimal comparisons really only work
* when both types already have the same precision and scale, and changing the scale
* loses information.
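*
* For example, bounds(5, 2) returns -999.99 as the lower bound and 999.99 as the upper bound.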
* @param precision the max precision of decimal type
* @param scale the scale of decimal type
* @return a Map Entry of BigDecimal, lower bound as the key, upper bound as the value
*/
public static Map.Entry<BigDecimal, BigDecimal> bounds(int precision, int scale) {
StringBuilder sb = new StringBuilder();
for (int i = 0; i < precision; i++) sb.append("9");
sb.append("e");
sb.append(-scale);
String boundStr = sb.toString();
BigDecimal upperBound = new BigDecimal(boundStr);
BigDecimal lowerBound = new BigDecimal("-" + boundStr);
return new AbstractMap.SimpleImmutableEntry<>(lowerBound, upperBound);
}

/**
* Given precision and scale, checks each value of the input decimal column for being out of bounds.
* @return a boolean column indicating whether each value is out of bounds
*/
public static ColumnVector outOfBounds(ColumnView input, int precision, int scale) {
Map.Entry<BigDecimal, BigDecimal> boundPair = bounds(precision, scale);
BigDecimal lowerBound = boundPair.getKey();
BigDecimal upperBound = boundPair.getValue();
try (ColumnVector over = greaterThan(input, upperBound);
ColumnVector under = lessThan(input, lowerBound)) {
return over.or(under);
}
}

/**
* The native lessThan operator has issues accurately comparing decimal values that have
* different precision and scale, so this method takes some special steps to work around them.
*/
public static ColumnVector lessThan(ColumnView lhs, BigDecimal rhs) {
assert (lhs.getType().isDecimalType());
int leftScale = lhs.getType().getScale();
int leftPrecision = lhs.getType().getDecimalMaxPrecision();

// First we have to round the scalar (rhs) to the same scale as lhs. Because this is a
// less than and it is rhs that we are rounding, we will round away from 0 (UP)
// to make sure we always return the correct value.
// For example:
// 100.1 < 100.19
// If we rounded down the rhs 100.19 would become 100.1, and now 100.1 is not < 100.1
BigDecimal roundedRhs = rhs.setScale(-leftScale, BigDecimal.ROUND_UP);

if (roundedRhs.precision() > leftPrecision) {
// converting rhs to the same precision as lhs would result in an overflow/error, but
// the scale is the same so we can still figure this out. For example if LHS precision is
// 4 and RHS precision is 5 we get the following...
// 9999 < 99999 => true
// -9999 < 99999 => true
// 9999 < -99999 => false
// -9999 < -99999 => false
// so the result should be the same as RHS > 0
try (Scalar isPositive = Scalar.fromBool(roundedRhs.compareTo(BigDecimal.ZERO) > 0)) {
return ColumnVector.fromScalar(isPositive, (int) lhs.getRowCount());
}
}
try (Scalar scalarRhs = Scalar.fromDecimal(roundedRhs.unscaledValue(), lhs.getType())) {
return lhs.lessThan(scalarRhs);
}
}

/**
* The native lessThan operator has issues accurately comparing decimal values that have
* different precision and scale, so this method takes some special steps to work around them.
*/
public static ColumnVector lessThan(BinaryOperable lhs, BigDecimal rhs, int numRows) {
if (lhs instanceof ColumnView) {
return lessThan((ColumnView) lhs, rhs);
}
Scalar scalarLhs = (Scalar) lhs;
if (scalarLhs.isValid()) {
try (Scalar isLess = Scalar.fromBool(scalarLhs.getBigDecimal().compareTo(rhs) < 0)) {
return ColumnVector.fromScalar(isLess, numRows);
}
}
try (Scalar nullScalar = Scalar.fromNull(DType.BOOL8)) {
return ColumnVector.fromScalar(nullScalar, numRows);
}
}

/**
* The native greaterThan operator has issues accurately comparing decimal values that have
* different precision and scale, so this method takes some special steps to work around them.
*/
public static ColumnVector greaterThan(ColumnView lhs, BigDecimal rhs) {
assert (lhs.getType().isDecimalType());
int cvScale = lhs.getType().getScale();
int maxPrecision = lhs.getType().getDecimalMaxPrecision();

// First we have to round the scalar (rhs) to the same scale as lhs. Because this is a
// greater than and it is rhs that we are rounding, we will round towards 0 (DOWN)
// to make sure we always return the correct value.
// For example:
// 100.2 > 100.19
// If we rounded up the rhs 100.19 would become 100.2, and now 100.2 is not > 100.2
BigDecimal roundedRhs = rhs.setScale(-cvScale, BigDecimal.ROUND_DOWN);

if (roundedRhs.precision() > maxPrecision) {
// converting rhs to the same precision as lhs would result in an overflow/error, but
// the scale is the same so we can still figure this out. For example if LHS precision is
// 4 and RHS precision is 5 we get the following...
// 9999 > 99999 => false
// -9999 > 99999 => false
// 9999 > -99999 => true
// -9999 > -99999 => true
// so the result should be the same as RHS < 0
try (Scalar isNegative = Scalar.fromBool(roundedRhs.compareTo(BigDecimal.ZERO) < 0)) {
return ColumnVector.fromScalar(isNegative, (int) lhs.getRowCount());
}
}
try (Scalar scalarRhs = Scalar.fromDecimal(roundedRhs.unscaledValue(), lhs.getType())) {
return lhs.greaterThan(scalarRhs);
}
}
}
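
A short editorial sketch (not part of the committed file) of the rounding behaviour described in the comments above; it assumes the existing ColumnVector.decimalFromLongs factory and a column at scale -1:

// Column values with one fractional digit: 100.1 and 100.2.
try (ColumnVector col = ColumnVector.decimalFromLongs(-1, 1001L, 1002L);
     // 100.19 is rounded UP to 100.2 before the native compare, giving [true, false],
     // which matches the exact results of 100.1 < 100.19 and 100.2 < 100.19.
     ColumnVector lt = DecimalUtils.lessThan(col, new java.math.BigDecimal("100.19"));
     // 100.19 is rounded DOWN to 100.1 before the native compare, giving [false, true].
     ColumnVector gt = DecimalUtils.greaterThan(col, new java.math.BigDecimal("100.19"))) {
  // When the rounded rhs no longer fits the column's precision, both methods fall back to a
  // constant answer derived from the sign of rhs, as in the overflow branches above.
}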
15 changes: 11 additions & 4 deletions java/src/main/java/ai/rapids/cudf/Scalar.java
@@ -261,15 +261,22 @@ public static Scalar fromDecimal(BigDecimal value) {
       return Scalar.fromNull(DType.create(DType.DTypeEnum.DECIMAL64, 0));
     }
     DType dt = DType.fromJavaBigDecimal(value);
+    return fromDecimal(value.unscaledValue(), dt);
+  }
+
+  public static Scalar fromDecimal(BigInteger unscaledValue, DType dt) {
+    if (unscaledValue == null) {
+      return Scalar.fromNull(dt);
+    }
     long handle;
     if (dt.typeId == DType.DTypeEnum.DECIMAL32) {
-      handle = makeDecimal32Scalar(value.unscaledValue().intValueExact(), -value.scale(), true);
+      handle = makeDecimal32Scalar(unscaledValue.intValueExact(), dt.getScale(), true);
     } else if (dt.typeId == DType.DTypeEnum.DECIMAL64) {
-      handle = makeDecimal64Scalar(value.unscaledValue().longValueExact(), -value.scale(), true);
+      handle = makeDecimal64Scalar(unscaledValue.longValueExact(), dt.getScale(), true);
     } else {
-      byte[] unscaledValueBytes = value.unscaledValue().toByteArray();
+      byte[] unscaledValueBytes = unscaledValue.toByteArray();
       byte[] finalBytes = convertDecimal128FromJavaToCudf(unscaledValueBytes);
-      handle = makeDecimal128Scalar(finalBytes, -value.scale(), true);
+      handle = makeDecimal128Scalar(finalBytes, dt.getScale(), true);
     }
     return new Scalar(dt, handle);
   }
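
An editorial sketch (not part of this commit) of what the new overload enables: building a decimal scalar from an unscaled BigInteger plus an explicit cuDF type, so callers such as DecimalUtils.lessThan control precision and scale directly.

// DECIMAL64 with scale -2, i.e. two fractional digits.
DType dec64 = DType.create(DType.DTypeEnum.DECIMAL64, -2);
try (Scalar s = Scalar.fromDecimal(new java.math.BigInteger("12345"), dec64)) {
  // Represents 123.45; the original fromDecimal(BigDecimal) entry point now simply
  // derives a DType from the BigDecimal and delegates to this overload.
}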