diff --git a/java/src/main/java/ai/rapids/cudf/DType.java b/java/src/main/java/ai/rapids/cudf/DType.java
index 2e5b0202dc5..d0bb7761da4 100644
--- a/java/src/main/java/ai/rapids/cudf/DType.java
+++ b/java/src/main/java/ai/rapids/cudf/DType.java
@@ -173,6 +173,51 @@ private DType(DTypeEnum id, int decimalScale) {
       STRUCT
   };
 
+  /**
+   * Returns the maximum precision of this decimal type.
+   * @return the largest number of decimal digits this decimal type can hold
+   * @throws IllegalArgumentException if this is not a decimal type
+   */
+  public int getDecimalMaxPrecision() {
+    if (!isDecimalType()) {
+      throw new IllegalArgumentException("not a decimal type: " + this);
+    }
+    switch (typeId) {
+      case DECIMAL32:
+        return DECIMAL32_MAX_PRECISION;
+      case DECIMAL64:
+        return DECIMAL64_MAX_PRECISION;
+      default:
+        // isDecimalType() passed, so anything else is treated as DECIMAL128
+        return DECIMAL128_MAX_PRECISION;
+    }
+  }
+
+  /**
+   * Gets the number of decimal digits needed to hold any value of this integral type.
+   * NOTE: this method is NOT for decimal types but for integral types.
+   * @return the minimum decimal precision (digits) for this integral type
+   * @throws IllegalArgumentException if this is not INT8, INT16, INT32 or INT64
+   */
+  public int getPrecisionForInt() {
+    switch (typeId) {
+      case INT8:
+        // -128 to 127
+        return 3;
+      case INT16:
+        // -32768 to 32767
+        return 5;
+      case INT32:
+        // -2147483648 to 2147483647
+        return 10;
+      case INT64:
+        // -9223372036854775808 to 9223372036854775807
+        return 19;
+      default:
+        throw new IllegalArgumentException("not an integral type: " + this);
+    }
+  }
+
   /**
    * This only works for fixed width types. Variable width types like strings the value is
    * undefined and should be ignored.
diff --git a/java/src/main/java/ai/rapids/cudf/DecimalUtils.java b/java/src/main/java/ai/rapids/cudf/DecimalUtils.java
new file mode 100644
index 00000000000..1979bd1bd5b
--- /dev/null
+++ b/java/src/main/java/ai/rapids/cudf/DecimalUtils.java
@@ -0,0 +1,199 @@
+/*
+ *
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package ai.rapids.cudf;
+
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+import java.util.AbstractMap;
+import java.util.Map;
+
+public class DecimalUtils {
+
+  /**
+   * Creates the narrowest cuDF decimal type that can hold the given precision.
+   * @param precision the number of decimal digits the type has to hold
+   * @param scale the scale of the decimal type, negated when passed to DType.create
+   * @return a DECIMAL32, DECIMAL64 or DECIMAL128 type
+   * @throws IllegalArgumentException if precision is larger than DType.DECIMAL128_MAX_PRECISION
+   */
+  public static DType createDecimalType(int precision, int scale) {
+    final DType.DTypeEnum typeId;
+    if (precision <= DType.DECIMAL32_MAX_PRECISION) {
+      typeId = DType.DTypeEnum.DECIMAL32;
+    } else if (precision <= DType.DECIMAL64_MAX_PRECISION) {
+      typeId = DType.DTypeEnum.DECIMAL64;
+    } else if (precision <= DType.DECIMAL128_MAX_PRECISION) {
+      typeId = DType.DTypeEnum.DECIMAL128;
+    } else {
+      throw new IllegalArgumentException("precision overflow: " + precision);
+    }
+    return DType.create(typeId, -scale);
+  }
+
+  /**
+   * Given decimal precision and scale, returns the lower and upper bound of that decimal type.
+   *
+   * Be very careful when comparing against these bounds: cuDF decimal comparisons really only
+   * work when both sides already have the same precision and scale, and changing the scale
+   * loses information.
+   * @param precision the max precision of decimal type, must be positive
+   * @param scale the scale of decimal type
+   * @return a Map Entry of BigDecimal, lower bound as the key, upper bound as the value
+   * @throws IllegalArgumentException if precision is not positive
+   */
+  public static Map.Entry<BigDecimal, BigDecimal> bounds(int precision, int scale) {
+    if (precision <= 0) {
+      throw new IllegalArgumentException("precision must be positive: " + precision);
+    }
+    // The largest magnitude is `precision` nines with the decimal point moved by scale,
+    // e.g. precision 4 and scale 2 gives 99.99
+    BigDecimal upperBound =
+        BigDecimal.TEN.pow(precision).subtract(BigDecimal.ONE).scaleByPowerOfTen(-scale);
+    BigDecimal lowerBound = upperBound.negate();
+    return new AbstractMap.SimpleImmutableEntry<>(lowerBound, upperBound);
+  }
+
+  /**
+   * With precision and scale, checks each value of input decimal column for out of bound.
+   * @param input the decimal column to check
+   * @param precision the precision used to compute the bounds
+   * @param scale the scale used to compute the bounds
+   * @return the boolean column represents whether specific values are out of bound or not
+   */
+  public static ColumnVector outOfBounds(ColumnView input, int precision, int scale) {
+    Map.Entry<BigDecimal, BigDecimal> boundPair = bounds(precision, scale);
+    BigDecimal lowerBound = boundPair.getKey();
+    BigDecimal upperBound = boundPair.getValue();
+    try (ColumnVector over = greaterThan(input, upperBound);
+         ColumnVector under = lessThan(input, lowerBound)) {
+      return over.or(under);
+    }
+  }
+
+  /**
+   * Compares every value of a decimal column with a BigDecimal using less than.
+   *
+   * The native lessThan operator has issues with accurately comparing decimal values that have
+   * different precision and scale, so this method takes some special steps to avoid them.
+   * @param lhs the decimal column on the left hand side
+   * @param rhs the value on the right hand side
+   * @return a boolean column holding the result of lhs less than rhs
+   */
+  public static ColumnVector lessThan(ColumnView lhs, BigDecimal rhs) {
+    assert (lhs.getType().isDecimalType());
+    int leftScale = lhs.getType().getScale();
+    int leftPrecision = lhs.getType().getDecimalMaxPrecision();
+
+    // First we have to round the scalar (rhs) to the same scale as lhs. Because this is a
+    // less than and it is rhs that we are rounding, we round towards positive infinity
+    // (CEILING) so that no lhs value can sit between the real rhs and the rounded rhs.
+    // For example:
+    //      100.1 < 100.19
+    // If we rounded down the rhs 100.19 would become 100.1, and now 100.1 is not < 100.1
+    // Rounding away from 0 (UP) is not enough, because it breaks a negative rhs:
+    //      -100.2 < -100.19
+    // If we rounded UP the rhs -100.19 would become -100.2, and now -100.2 is not < -100.2
+    BigDecimal roundedRhs = rhs.setScale(-leftScale, RoundingMode.CEILING);
+
+    if (roundedRhs.precision() > leftPrecision) {
+      // converting rhs to the same precision as lhs would result in an overflow/error, but
+      // the scale is the same so we can still figure this out. For example if LHS precision is
+      // 4 and RHS precision is 5 we get the following...
+      //  9999 <  99999 => true
+      // -9999 <  99999 => true
+      //  9999 < -99999 => false
+      // -9999 < -99999 => false
+      // so the result should be the same as RHS > 0
+      try (Scalar isPositive = Scalar.fromBool(roundedRhs.compareTo(BigDecimal.ZERO) > 0)) {
+        return ColumnVector.fromScalar(isPositive, (int) lhs.getRowCount());
+      }
+    }
+    try (Scalar scalarRhs = Scalar.fromDecimal(roundedRhs.unscaledValue(), lhs.getType())) {
+      return lhs.lessThan(scalarRhs);
+    }
+  }
+
+  /**
+   * Less than comparison for a left hand side that is either a column or a scalar.
+   *
+   * A column is handed to lessThan(ColumnView, BigDecimal). A scalar is compared to rhs
+   * with BigDecimal.compareTo and the result, or null for an invalid scalar, fills numRows rows.
+   * @param lhs a ColumnView or a Scalar
+   * @param rhs the value on the right hand side
+   * @param numRows number of rows of the result when lhs is a Scalar
+   * @return a boolean column holding the result of lhs less than rhs
+   */
+  public static ColumnVector lessThan(BinaryOperable lhs, BigDecimal rhs, int numRows) {
+    if (lhs instanceof ColumnView) {
+      return lessThan((ColumnView) lhs, rhs);
+    }
+    Scalar scalarLhs = (Scalar) lhs;
+    if (scalarLhs.isValid()) {
+      try (Scalar isLess = Scalar.fromBool(scalarLhs.getBigDecimal().compareTo(rhs) < 0)) {
+        return ColumnVector.fromScalar(isLess, numRows);
+      }
+    }
+    try (Scalar nullScalar = Scalar.fromNull(DType.BOOL8)) {
+      return ColumnVector.fromScalar(nullScalar, numRows);
+    }
+  }
+
+  /**
+   * Compares every value of a decimal column with a BigDecimal using greater than.
+   *
+   * The native greaterThan operator has issues with accurately comparing decimal values that
+   * have different precision and scale, so this method takes some special steps to avoid them.
+   * @param lhs the decimal column on the left hand side
+   * @param rhs the value on the right hand side
+   * @return a boolean column holding the result of lhs greater than rhs
+   */
+  public static ColumnVector greaterThan(ColumnView lhs, BigDecimal rhs) {
+    assert (lhs.getType().isDecimalType());
+    int cvScale = lhs.getType().getScale();
+    int maxPrecision = lhs.getType().getDecimalMaxPrecision();
+
+    // First we have to round the scalar (rhs) to the same scale as lhs. Because this is a
+    // greater than and it is rhs that we are rounding, we round towards negative infinity
+    // (FLOOR) so that no lhs value can sit between the rounded rhs and the real rhs.
+    // For example:
+    //      100.2 > 100.19
+    // If we rounded up the rhs 100.19 would become 100.2, and now 100.2 is not > 100.2
+    // Rounding towards 0 (DOWN) is not enough, because it breaks a negative rhs:
+    //      -100.1 > -100.19
+    // If we rounded DOWN the rhs -100.19 would become -100.1, and now -100.1 is not > -100.1
+    BigDecimal roundedRhs = rhs.setScale(-cvScale, RoundingMode.FLOOR);
+
+    if (roundedRhs.precision() > maxPrecision) {
+      // converting rhs to the same precision as lhs would result in an overflow/error, but
+      // the scale is the same so we can still figure this out. For example if LHS precision is
+      // 4 and RHS precision is 5 we get the following...
+      //  9999 >  99999 => false
+      // -9999 >  99999 => false
+      //  9999 > -99999 => true
+      // -9999 > -99999 => true
+      // so the result should be the same as RHS < 0
+      try (Scalar isNegative = Scalar.fromBool(roundedRhs.compareTo(BigDecimal.ZERO) < 0)) {
+        return ColumnVector.fromScalar(isNegative, (int) lhs.getRowCount());
+      }
+    }
+    try (Scalar scalarRhs = Scalar.fromDecimal(roundedRhs.unscaledValue(), lhs.getType())) {
+      return lhs.greaterThan(scalarRhs);
+    }
+  }
+}
diff --git a/java/src/main/java/ai/rapids/cudf/Scalar.java b/java/src/main/java/ai/rapids/cudf/Scalar.java
index 03e77573695..205efadfe6c 100644
--- a/java/src/main/java/ai/rapids/cudf/Scalar.java
+++ b/java/src/main/java/ai/rapids/cudf/Scalar.java
@@ -261,15 +261,33 @@ public static Scalar fromDecimal(BigDecimal value) {
       return Scalar.fromNull(DType.create(DType.DTypeEnum.DECIMAL64, 0));
     }
     DType dt = DType.fromJavaBigDecimal(value);
+    return fromDecimal(value.unscaledValue(), dt);
+  }
+
+  /**
+   * Creates a decimal scalar from an unscaled value and the decimal type it belongs to.
+   * @param unscaledValue the unscaled value of the decimal, or null for a null scalar
+   * @param dt the decimal type of the scalar; its scale is used as the scale of the value
+   * @return a scalar of type dt
+   * @throws IllegalArgumentException if unscaledValue is not null and dt is not a decimal type
+   * @throws ArithmeticException if unscaledValue does not fit in a DECIMAL32 or DECIMAL64 dt
+   */
+  public static Scalar fromDecimal(BigInteger unscaledValue, DType dt) {
+    if (unscaledValue == null) {
+      return Scalar.fromNull(dt);
+    }
+    if (!dt.isDecimalType()) {
+      throw new IllegalArgumentException("not a decimal type: " + dt);
+    }
     long handle;
     if (dt.typeId == DType.DTypeEnum.DECIMAL32) {
-      handle = makeDecimal32Scalar(value.unscaledValue().intValueExact(), -value.scale(), true);
+      handle = makeDecimal32Scalar(unscaledValue.intValueExact(), dt.getScale(), true);
     } else if (dt.typeId == DType.DTypeEnum.DECIMAL64) {
-      handle = makeDecimal64Scalar(value.unscaledValue().longValueExact(), -value.scale(), true);
+      handle = makeDecimal64Scalar(unscaledValue.longValueExact(), dt.getScale(), true);
     } else {
-      byte[] unscaledValueBytes = value.unscaledValue().toByteArray();
+      byte[] unscaledValueBytes = unscaledValue.toByteArray();
       byte[] finalBytes = convertDecimal128FromJavaToCudf(unscaledValueBytes);
-      handle = makeDecimal128Scalar(finalBytes, -value.scale(), true);
+      handle = makeDecimal128Scalar(finalBytes, dt.getScale(), true);
     }
     return new Scalar(dt, handle);
   }