Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Binary operations support for decimal type in cudf Java [skip ci] #6734

Merged
merged 9 commits into from
Nov 13, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
- PR #6727 Remove 2nd type-dispatcher call from cudf::reduce
- PR #6749 Update nested JNI builder so we can do it incrementally
- PR #6748 Add Java API to concatenate serialized tables to ContiguousTable
- PR #6734 Binary operations support for decimal type in cudf Java

## Bug Fixes

Expand Down
32 changes: 31 additions & 1 deletion java/src/main/java/ai/rapids/cudf/BinaryOperable.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ public interface BinaryOperable {
* <p>
* BOOL8 is treated like an INT8. Math on boolean operations makes little sense. If
* you want to stay as a BOOL8 you will need to explicitly specify the output type.
* For decimal types, DECIMAL32 and DECIMAL64 takes in another parameter `scale`. DType is created
* with scale=0 as scale is required. Dtype is discarded for binary operations for decimal
* types in cudf as a new DType is created for output type with the new scale.
*/
static DType implicitConversion(BinaryOperable lhs, BinaryOperable rhs) {
DType a = lhs.getType();
Expand Down Expand Up @@ -76,6 +79,25 @@ static DType implicitConversion(BinaryOperable lhs, BinaryOperable rhs) {
if (a == DType.BOOL8 || b == DType.BOOL8) {
return DType.BOOL8;
}
if (a.isDecimalType() && b.isDecimalType()) {
// Here scale is created with value 0 as `scale` is required to create DType of
// decimal type. Dtype is discarded for binary operations for decimal types in cudf as a new
// DType is created for output type with new scale. New scale for output depends upon operator.
int scale = 0;
revans2 marked this conversation as resolved.
Show resolved Hide resolved
if (a.typeId == DType.DTypeEnum.DECIMAL32) {
if (b.typeId == DType.DTypeEnum.DECIMAL32) {
return DType.create(DType.DTypeEnum.DECIMAL32, scale);
} else {
throw new IllegalArgumentException("Both columns must be of the same fixed_point type");
}
} else if (a.typeId == DType.DTypeEnum.DECIMAL64) {
if (b.typeId == DType.DTypeEnum.DECIMAL64) {
return DType.create(DType.DTypeEnum.DECIMAL64, scale);
} else {
throw new IllegalArgumentException("Both columns must be of the same fixed_point type");
}
}
}
throw new IllegalArgumentException("Unsupported types " + a + " and " + b);
}

Expand All @@ -94,7 +116,9 @@ static DType implicitConversion(BinaryOperable lhs, BinaryOperable rhs) {
ColumnVector binaryOp(BinaryOp op, BinaryOperable rhs, DType outType);

/**
* Add + operator. this + rhs
* Add one vector to another with the given output type. this + rhs
* Output type is ignored for the operations between decimal types and
* it is always decimal type.
*/
default ColumnVector add(BinaryOperable rhs, DType outType) {
return binaryOp(BinaryOp.ADD, rhs, outType);
Expand All @@ -109,6 +133,8 @@ default ColumnVector add(BinaryOperable rhs) {

/**
* Subtract one vector from another with the given output type. this - rhs
* Output type is ignored for the operations between decimal types and
* it is always decimal type.
*/
default ColumnVector sub(BinaryOperable rhs, DType outType) {
return binaryOp(BinaryOp.SUB, rhs, outType);
Expand All @@ -123,6 +149,8 @@ default ColumnVector sub(BinaryOperable rhs) {

/**
* Multiply two vectors together with the given output type. this * rhs
* Output type is ignored for the operations between decimal types and
* it is always decimal type.
*/
default ColumnVector mul(BinaryOperable rhs, DType outType) {
return binaryOp(BinaryOp.MUL, rhs, outType);
Expand All @@ -137,6 +165,8 @@ default ColumnVector mul(BinaryOperable rhs) {

/**
* Divide one vector by another with the given output type. this / rhs
* Output type is ignored for the operations between decimal types and
* it is always decimal type.
*/
default ColumnVector div(BinaryOperable rhs, DType outType) {
return binaryOp(BinaryOp.DIV, rhs, outType);
Expand Down
6 changes: 4 additions & 2 deletions java/src/main/java/ai/rapids/cudf/HostColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,8 @@ public static HostColumnVector fromDecimals(BigDecimal... values) {
BigDecimal maxDec = Arrays.stream(values).filter(Objects::nonNull)
.max(Comparator.comparingInt(BigDecimal::precision))
.orElse(BigDecimal.ZERO);
int maxScale = Arrays.stream(values)
.map(decimal -> (decimal == null) ? 0 : decimal.scale())
int maxScale = Arrays.stream(values).filter(Objects::nonNull)
revans2 marked this conversation as resolved.
Show resolved Hide resolved
.map(decimal -> decimal.scale())
.max(Comparator.naturalOrder())
.orElse(0);
maxDec = maxDec.setScale(maxScale, RoundingMode.UNNECESSARY);
Expand Down Expand Up @@ -1364,8 +1364,10 @@ public final Builder append(BigDecimal value, RoundingMode roundingMode) {
assert currentIndex < rows;
BigInteger unscaledValue = value.setScale(-type.getScale(), roundingMode).unscaledValue();
if (type.typeId == DType.DTypeEnum.DECIMAL32) {
assert value.precision() <= DType.DECIMAL32_MAX_PRECISION : "value exceeds maximum precision for DECIMAL32";
data.setInt(currentIndex * type.getSizeInBytes(), unscaledValue.intValueExact());
} else if (type.typeId == DType.DTypeEnum.DECIMAL64) {
assert value.precision() <= DType.DECIMAL64_MAX_PRECISION : "value exceeds maximum precision for DECIMAL64 ";
data.setLong(currentIndex * type.getSizeInBytes(), unscaledValue.longValueExact());
} else {
throw new IllegalStateException(type + " is not a supported decimal type.");
Expand Down
1 change: 1 addition & 0 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import ai.rapids.cudf.HostColumnVector.StructType;

import java.io.File;
import java.math.BigDecimal;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: the only change to this file is an uneeded import now.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand Down
Loading