Skip to content

Commit

Permalink
Add Java test
Browse files Browse the repository at this point in the history
Signed-off-by: Nghia Truong <[email protected]>
  • Loading branch information
ttnghia committed Jul 20, 2024
1 parent 0d22af8 commit 30b80e0
Showing 1 changed file with 26 additions and 2 deletions.
28 changes: 26 additions & 2 deletions src/test/java/com/nvidia/spark/rapids/jni/DecimalUtilsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.junit.jupiter.api.Test;

import java.math.BigDecimal;
import java.math.BigInteger;

import static ai.rapids.cudf.AssertUtils.*;

Expand Down Expand Up @@ -651,17 +652,40 @@ void floatingPointToDecimalTest() {
try (
ColumnVector input1 = ColumnVector.fromDoubles(3527.61953125);
ColumnVector input2 = ColumnVector.fromDoubles(9.95);
ColumnVector expected1 = ColumnVector.fromDecimals(BigDecimal.valueOf(35276195313L, 7));
ColumnVector expected2 = ColumnVector.fromDecimals(BigDecimal.valueOf(100L, 1))
ColumnVector input3 = ColumnVector.fromDoubles(10.3);
ColumnVector input4 = ColumnVector.fromDoubles(-10000000.0, -100000.0, 1.0, 100.0, 1000.0);
ColumnVector input5 = ColumnVector.fromDoubles(-10000000.0, 1.0, Double.NaN, -2.0, Double.NEGATIVE_INFINITY);

ColumnVector expected1 = ColumnVector.decimalFromLongs(-7, 35276195313L);
ColumnVector expected2 = ColumnVector.decimalFromInts(-1, 100);
ColumnVector expected3 = ColumnVector.decimalFromBigInt(-1, new BigInteger("103"));
ColumnVector expected4 = ColumnVector.decimalFromBoxedInts(-1, null, null, 10, 1000, null);
ColumnVector expected5 = ColumnVector.decimalFromBoxedLongs(-1, null, 10L, null, -20L, null)
) {
DecimalUtils.CastFloatToDecimalResult output1 = DecimalUtils.floatingPointToDecimal(input1, DType.create(DType.DTypeEnum.DECIMAL64, -7), 12);
DecimalUtils.CastFloatToDecimalResult output2 = DecimalUtils.floatingPointToDecimal(input2, DType.create(DType.DTypeEnum.DECIMAL32, -1), 3);
DecimalUtils.CastFloatToDecimalResult output3 = DecimalUtils.floatingPointToDecimal(input3, DType.create(DType.DTypeEnum.DECIMAL128, -1), 18);
DecimalUtils.CastFloatToDecimalResult output4 = DecimalUtils.floatingPointToDecimal(input4, DType.create(DType.DTypeEnum.DECIMAL32, -1), 4);
DecimalUtils.CastFloatToDecimalResult output5 = DecimalUtils.floatingPointToDecimal(input5, DType.create(DType.DTypeEnum.DECIMAL64, -1), 4);

try {
assert (!output1.hasFailure);
assert (!output2.hasFailure);
assert (!output3.hasFailure);
assert (output4.hasFailure);
assert (output5.hasFailure);

assertColumnsAreEqual(expected1, output1.result);
assertColumnsAreEqual(expected2, output2.result);
assertColumnsAreEqual(expected3, output3.result);
assertColumnsAreEqual(expected4, output4.result);
assertColumnsAreEqual(expected5, output5.result);
} finally {
output1.result.close();
output2.result.close();
output3.result.close();
output4.result.close();
output5.result.close();
}
}
}
Expand Down

0 comments on commit 30b80e0

Please sign in to comment.