Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JNI: Push back decimal utils from spark-rapids #9907

Merged
merged 8 commits into from
Feb 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions java/src/main/java/ai/rapids/cudf/DType.java
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,37 @@ private DType(DTypeEnum id, int decimalScale) {
STRUCT
};

/**
* Returns max precision for Decimal Type.
* @return max precision this Decimal Type can hold
*/
public int getDecimalMaxPrecision() {
if (!isDecimalType()) {
throw new IllegalArgumentException("not a decimal type: " + this);
}
if (typeId == DTypeEnum.DECIMAL32) return DECIMAL32_MAX_PRECISION;
if (typeId == DTypeEnum.DECIMAL64) return DECIMAL64_MAX_PRECISION;
return DType.DECIMAL128_MAX_PRECISION;
}

/**
* Get the number of decimal places needed to hold the Integral Type.
* NOTE: this method is NOT for Decimal Type but for Integral Type.
* @return the minimum decimal precision (places) for Integral Type
*/
public int getPrecisionForInt() {
// -128 to 127
if (typeId == DTypeEnum.INT8) return 3;
// -32768 to 32767
if (typeId == DTypeEnum.INT16) return 5;
// -2147483648 to 2147483647
if (typeId == DTypeEnum.INT32) return 10;
// -9223372036854775808 to 9223372036854775807
if (typeId == DTypeEnum.INT64) return 19;

throw new IllegalArgumentException("not an integral type: " + this);
}

/**
* This only works for fixed width types. Variable width types like strings the value is
* undefined and should be ignored.
Expand Down
164 changes: 164 additions & 0 deletions java/src/main/java/ai/rapids/cudf/DecimalUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
/*
*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package ai.rapids.cudf;

import java.math.BigDecimal;
import java.util.AbstractMap;
import java.util.Map;

public class DecimalUtils {

/**
* Creates a cuDF decimal type with precision and scale
*/
public static DType createDecimalType(int precision, int scale) {
if (precision <= DType.DECIMAL32_MAX_PRECISION) {
return DType.create(DType.DTypeEnum.DECIMAL32, -scale);
} else if (precision <= DType.DECIMAL64_MAX_PRECISION) {
return DType.create(DType.DTypeEnum.DECIMAL64, -scale);
} else if (precision <= DType.DECIMAL128_MAX_PRECISION) {
return DType.create(DType.DTypeEnum.DECIMAL128, -scale);
}
throw new IllegalArgumentException("precision overflow: " + precision);
}

/**
* Given decimal precision and scale, returns the lower and upper bound of current decimal type.
*
* Be very careful when comparing these CUDF decimal comparisons really only work
* when both types are already the same precision and scale, and when you change the scale
* you end up losing information.
* @param precision the max precision of decimal type
* @param scale the scale of decimal type
* @return a Map Entry of BigDecimal, lower bound as the key, upper bound as the value
*/
public static Map.Entry<BigDecimal, BigDecimal> bounds(int precision, int scale) {
StringBuilder sb = new StringBuilder();
for (int i = 0; i < precision; i++) sb.append("9");
sb.append("e");
sb.append(-scale);
String boundStr = sb.toString();
BigDecimal upperBound = new BigDecimal(boundStr);
BigDecimal lowerBound = new BigDecimal("-" + boundStr);
return new AbstractMap.SimpleImmutableEntry<>(lowerBound, upperBound);
}

/**
* With precision and scale, checks each value of input decimal column for out of bound.
* @return the boolean column represents whether specific values are out of bound or not
*/
public static ColumnVector outOfBounds(ColumnView input, int precision, int scale) {
Map.Entry<BigDecimal, BigDecimal> boundPair = bounds(precision, scale);
BigDecimal lowerBound = boundPair.getKey();
BigDecimal upperBound = boundPair.getValue();
try (ColumnVector over = greaterThan(input, upperBound);
ColumnVector under = lessThan(input, lowerBound)) {
return over.or(under);
}
}

/**
* Because the native lessThan operator has issues with comparing decimal values that have different
* precision and scale accurately. This method takes some special steps to get rid of these issues.
*/
public static ColumnVector lessThan(ColumnView lhs, BigDecimal rhs) {
assert (lhs.getType().isDecimalType());
int leftScale = lhs.getType().getScale();
int leftPrecision = lhs.getType().getDecimalMaxPrecision();

// First we have to round the scalar (rhs) to the same scale as lhs. Because this is a
// less than and it is rhs that we are rounding, we will round away from 0 (UP)
// to make sure we always return the correct value.
// For example:
// 100.1 < 100.19
// If we rounded down the rhs 100.19 would become 100.1, and now 100.1 is not < 100.1
BigDecimal roundedRhs = rhs.setScale(-leftScale, BigDecimal.ROUND_UP);

if (roundedRhs.precision() > leftPrecision) {
// converting rhs to the same precision as lhs would result in an overflow/error, but
// the scale is the same so we can still figure this out. For example if LHS precision is
// 4 and RHS precision is 5 we get the following...
// 9999 < 99999 => true
// -9999 < 99999 => true
// 9999 < -99999 => false
// -9999 < -99999 => false
// so the result should be the same as RHS > 0
try (Scalar isPositive = Scalar.fromBool(roundedRhs.compareTo(BigDecimal.ZERO) > 0)) {
return ColumnVector.fromScalar(isPositive, (int) lhs.getRowCount());
}
}
try (Scalar scalarRhs = Scalar.fromDecimal(roundedRhs.unscaledValue(), lhs.getType())) {
return lhs.lessThan(scalarRhs);
}
}

/**
* Because the native lessThan operator has issues with comparing decimal values that have different
* precision and scale accurately. This method takes some special steps to get rid of these issues.
*/
public static ColumnVector lessThan(BinaryOperable lhs, BigDecimal rhs, int numRows) {
if (lhs instanceof ColumnView) {
return lessThan((ColumnView) lhs, rhs);
}
Scalar scalarLhs = (Scalar) lhs;
if (scalarLhs.isValid()) {
try (Scalar isLess = Scalar.fromBool(scalarLhs.getBigDecimal().compareTo(rhs) < 0)) {
return ColumnVector.fromScalar(isLess, numRows);
}
}
try (Scalar nullScalar = Scalar.fromNull(DType.BOOL8)) {
return ColumnVector.fromScalar(nullScalar, numRows);
}
}

/**
* Because the native greaterThan operator has issues with comparing decimal values that have different
* precision and scale accurately. This method takes some special steps to get rid of these issues.
*/
public static ColumnVector greaterThan(ColumnView lhs, BigDecimal rhs) {
assert (lhs.getType().isDecimalType());
int cvScale = lhs.getType().getScale();
int maxPrecision = lhs.getType().getDecimalMaxPrecision();

// First we have to round the scalar (rhs) to the same scale as lhs. Because this is a
// greater than and it is rhs that we are rounding, we will round towards 0 (DOWN)
// to make sure we always return the correct value.
// For example:
// 100.2 > 100.19
// If we rounded up the rhs 100.19 would become 100.2, and now 100.2 is not > 100.2
BigDecimal roundedRhs = rhs.setScale(-cvScale, BigDecimal.ROUND_DOWN);

if (roundedRhs.precision() > maxPrecision) {
// converting rhs to the same precision as lhs would result in an overflow/error, but
// the scale is the same so we can still figure this out. For example if LHS precision is
// 4 and RHS precision is 5 we get the following...
// 9999 > 99999 => false
// -9999 > 99999 => false
// 9999 > -99999 => true
// -9999 > -99999 => true
// so the result should be the same as RHS < 0
try (Scalar isNegative = Scalar.fromBool(roundedRhs.compareTo(BigDecimal.ZERO) < 0)) {
return ColumnVector.fromScalar(isNegative, (int) lhs.getRowCount());
}
}
try (Scalar scalarRhs = Scalar.fromDecimal(roundedRhs.unscaledValue(), lhs.getType())) {
return lhs.greaterThan(scalarRhs);
}
}
}
15 changes: 11 additions & 4 deletions java/src/main/java/ai/rapids/cudf/Scalar.java
Original file line number Diff line number Diff line change
Expand Up @@ -261,15 +261,22 @@ public static Scalar fromDecimal(BigDecimal value) {
return Scalar.fromNull(DType.create(DType.DTypeEnum.DECIMAL64, 0));
}
DType dt = DType.fromJavaBigDecimal(value);
return fromDecimal(value.unscaledValue(), dt);
}

public static Scalar fromDecimal(BigInteger unscaledValue, DType dt) {
if (unscaledValue == null) {
return Scalar.fromNull(dt);
}
long handle;
if (dt.typeId == DType.DTypeEnum.DECIMAL32) {
handle = makeDecimal32Scalar(value.unscaledValue().intValueExact(), -value.scale(), true);
handle = makeDecimal32Scalar(unscaledValue.intValueExact(), dt.getScale(), true);
} else if (dt.typeId == DType.DTypeEnum.DECIMAL64) {
handle = makeDecimal64Scalar(value.unscaledValue().longValueExact(), -value.scale(), true);
handle = makeDecimal64Scalar(unscaledValue.longValueExact(), dt.getScale(), true);
} else {
byte[] unscaledValueBytes = value.unscaledValue().toByteArray();
byte[] unscaledValueBytes = unscaledValue.toByteArray();
byte[] finalBytes = convertDecimal128FromJavaToCudf(unscaledValueBytes);
handle = makeDecimal128Scalar(finalBytes, -value.scale(), true);
handle = makeDecimal128Scalar(finalBytes, dt.getScale(), true);
}
return new Scalar(dt, handle);
}
Expand Down