diff --git a/CHANGELOG.md b/CHANGELOG.md index 78071b58adf..8c162a5de28 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,7 +53,6 @@ - PR #6622 Update `to_pandas` api docs - PR #6623 Add operator overloading to column and clean up error messages - PR #6635 Add cudf::test::dictionary_column_wrapper class -- PR #6609 Support fixed-point decimal for HostColumnVector ## Bug Fixes diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index 1930d36cef9..4c5739b5f3b 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -24,8 +24,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.math.BigDecimal; -import java.math.RoundingMode; import java.util.ArrayList; import java.util.List; import java.util.Optional; @@ -3777,41 +3775,6 @@ public static ColumnVector timestampNanoSecondsFromLongs(long... values) { return build(DType.TIMESTAMP_NANOSECONDS, values.length, (b) -> b.appendArray(values)); } - /** - * Create a new decimal vector from unscaled values (int array) and scale. - * The created vector is of type DType.DECIMAL32, whose max precision is 9. - * Compared with scale of [[java.math.BigDecimal]], the scale here represents the opposite meaning. - */ - public static ColumnVector decimalFromInts(int scale, int... values) { - try (HostColumnVector host = HostColumnVector.decimalFromInts(scale, values)) { - return host.copyToDevice(); - } - } - - /** - * Create a new decimal vector from unscaled values (long array) and scale. - * The created vector is of type DType.DECIMAL64, whose max precision is 18. - * Compared with scale of [[java.math.BigDecimal]], the scale here represents the opposite meaning. - */ - public static ColumnVector decimalFromLongs(int scale, long... values) { - try (HostColumnVector host = HostColumnVector.decimalFromLongs(scale, values)) { - return host.copyToDevice(); - } - } - - /** - * Create a new decimal vector from double floats with specific DecimalType and RoundingMode. - * All doubles will be rescaled if necessary, according to scale of input DecimalType and RoundingMode. - * If any overflow occurs in extracting integral part, an IllegalArgumentException will be thrown. - * This API is inefficient because of slow double -> decimal conversion, so it is mainly for testing. - * Compared with scale of [[java.math.BigDecimal]], the scale here represents the opposite meaning. - */ - public static ColumnVector decimalFromDoubles(DType type, RoundingMode mode, double... values) { - try (HostColumnVector host = HostColumnVector.decimalFromDoubles(type, mode, values)) { - return host.copyToDevice(); - } - } - /** * Create a new string vector from the given values. This API * supports inline nulls. This is really intended to be used only for testing as @@ -3823,19 +3786,6 @@ public static ColumnVector fromStrings(String... values) { } } - /** - * Create a new vector from the given values. This API supports inline nulls, - * but is much slower than building from primitive array of unscaledValues. - * Notice: - * 1. All input BigDecimals should share same scale. - * 2. The scale will be zero if all input values are null. - */ - public static ColumnVector fromDecimals(BigDecimal... values) { - try (HostColumnVector hcv = HostColumnVector.fromDecimals(values)) { - return hcv.copyToDevice(); - } - } - /** * Create a new vector from the given values. This API supports inline nulls, * but is much slower than using a regular array and should really only be used diff --git a/java/src/main/java/ai/rapids/cudf/DType.java b/java/src/main/java/ai/rapids/cudf/DType.java index 58fd47d77d6..9d32a7c40ec 100644 --- a/java/src/main/java/ai/rapids/cudf/DType.java +++ b/java/src/main/java/ai/rapids/cudf/DType.java @@ -21,8 +21,8 @@ public final class DType { - public static final int DECIMAL32_MAX_PRECISION = 9; - public static final int DECIMAL64_MAX_PRECISION = 18; + public static final int DECIMAL32_MAX_PRECISION = 10; + public static final int DECIMAL64_MAX_PRECISION = 19; /* enum representing various types. Whenever a new non-decimal type is added please make sure below sections are updated as well: @@ -92,8 +92,6 @@ public enum DTypeEnum { } public int getNativeId() { return nativeId; } - - public boolean isDecimalType() { return DType.DECIMALS.contains(this); } } final DTypeEnum typeId; @@ -233,7 +231,7 @@ public String toString() { * @return DType */ public static DType create(DTypeEnum dt) { - if (DType.DECIMALS.contains(dt)) { + if (dt == DTypeEnum.DECIMAL32 || dt == DTypeEnum.DECIMAL64) { throw new IllegalArgumentException("Could not create a Decimal DType without scale"); } return DType.fromNative(dt.nativeId, 0); @@ -249,7 +247,7 @@ public static DType create(DTypeEnum dt) { * @return DType */ public static DType create(DTypeEnum dt, int scale) { - if (!DType.DECIMALS.contains(dt)) { + if (dt != DTypeEnum.DECIMAL32 && dt != DTypeEnum.DECIMAL64) { throw new IllegalArgumentException("Could not create a non-Decimal DType with scale"); } return DType.fromNative(dt.nativeId, scale); @@ -321,8 +319,7 @@ public boolean hasTimeResolution() { * DType.INT32, * DType.UINT32, * DType.DURATION_DAYS, - * DType.TIMESTAMP_DAYS, - * DType.DECIMAL32 + * DType.TIMESTAMP_DAYS */ public boolean isBackedByInt() { return INTS.contains(this.typeId); @@ -340,8 +337,7 @@ public boolean isBackedByInt() { * DType.TIMESTAMP_SECONDS, * DType.TIMESTAMP_MILLISECONDS, * DType.TIMESTAMP_MICROSECONDS, - * DType.TIMESTAMP_NANOSECONDS, - * DType.DECIMAL64 + * DType.TIMESTAMP_NANOSECONDS */ public boolean isBackedByLong() { return LONGS.contains(this.typeId); @@ -370,7 +366,7 @@ public boolean isBackedByLong() { * DType.DECIMAL32, * DType.DECIMAL64 */ - public boolean isDecimalType() { return this.typeId.isDecimalType(); } + public boolean isDecimalType() { return DECIMALS.contains(this.typeId); } /** * Returns true for duration types @@ -426,18 +422,14 @@ public boolean isTimestampType() { DTypeEnum.TIMESTAMP_SECONDS, DTypeEnum.TIMESTAMP_MILLISECONDS, DTypeEnum.TIMESTAMP_MICROSECONDS, - DTypeEnum.TIMESTAMP_NANOSECONDS, - // The unscaledValue of DECIMAL64 is of type INT64, which means it can be fetched by getLong. - DTypeEnum.DECIMAL64 + DTypeEnum.TIMESTAMP_NANOSECONDS ); private static final EnumSet INTS = EnumSet.of( DTypeEnum.INT32, DTypeEnum.UINT32, DTypeEnum.DURATION_DAYS, - DTypeEnum.TIMESTAMP_DAYS, - // The unscaledValue of DECIMAL32 is of type INT32, which means it can be fetched by getInt. - DTypeEnum.DECIMAL32 + DTypeEnum.TIMESTAMP_DAYS ); private static final EnumSet SHORTS = EnumSet.of( diff --git a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java index 2f32c9af565..48e7e60df06 100644 --- a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java @@ -18,15 +18,9 @@ package ai.rapids.cudf; -import java.math.BigDecimal; -import java.math.BigInteger; -import java.math.RoundingMode; import java.nio.charset.StandardCharsets; import java.util.ArrayList; -import java.util.Arrays; -import java.util.Comparator; import java.util.List; -import java.util.Objects; import java.util.Optional; import java.util.StringJoiner; import java.util.function.Consumer; @@ -447,50 +441,6 @@ public static HostColumnVector timestampNanoSecondsFromLongs(long... values) { return build(DType.TIMESTAMP_NANOSECONDS, values.length, (b) -> b.appendArray(values)); } - /** - * Create a new decimal vector from unscaled values (int array) and scale. - * The created vector is of type DType.DECIMAL32, whose max precision is 9. - * Compared with scale of [[java.math.BigDecimal]], the scale here represents the opposite meaning. - */ - public static HostColumnVector decimalFromInts(int scale, int... values) { - return build(DType.create(DType.DTypeEnum.DECIMAL32, scale), values.length, (b) -> b.appendUnscaledDecimalArray(values)); - } - - /** - * Create a new decimal vector from unscaled values (long array) and scale. - * The created vector is of type DType.DECIMAL64, whose max precision is 18. - * Compared with scale of [[java.math.BigDecimal]], the scale here represents the opposite meaning. - */ - public static HostColumnVector decimalFromLongs(int scale, long... values) { - return build(DType.create(DType.DTypeEnum.DECIMAL64, scale), values.length, (b) -> b.appendUnscaledDecimalArray(values)); - } - - /** - * Create a new decimal vector from double floats with specific DecimalType and RoundingMode. - * All doubles will be rescaled if necessary, according to scale of input DecimalType and RoundingMode. - * If any overflow occurs in extracting integral part, an IllegalArgumentException will be thrown. - * This API is inefficient because of slow double -> decimal conversion, so it is mainly for testing. - * Compared with scale of [[java.math.BigDecimal]], the scale here represents the opposite meaning. - */ - public static HostColumnVector decimalFromDoubles(DType type, RoundingMode mode, double... values) { - assert type.isDecimalType(); - if (type.typeId == DType.DTypeEnum.DECIMAL64) { - long[] data = new long[values.length]; - for (int i = 0; i < values.length; i++) { - BigDecimal dec = BigDecimal.valueOf(values[i]).setScale(-type.getScale(), mode); - data[i] = dec.unscaledValue().longValueExact(); - } - return build(type, values.length, (b) -> b.appendUnscaledDecimalArray(data)); - } else { - int[] data = new int[values.length]; - for (int i = 0; i < values.length; i++) { - BigDecimal dec = BigDecimal.valueOf(values[i]).setScale(-type.getScale(), mode); - data[i] = dec.unscaledValue().intValueExact(); - } - return build(type, values.length, (b) -> b.appendUnscaledDecimalArray(data)); - } - } - /** * Create a new string vector from the given values. This API * supports inline nulls. This is really intended to be used only for testing as @@ -518,30 +468,6 @@ public static HostColumnVector fromStrings(String... values) { }); } - /** - * Create a new vector from the given values. This API supports inline nulls, - * but is much slower than building from primitive array of unscaledValues. - * Notice: - * 1. Input values will be rescaled with min scale (max scale in terms of java.math.BigDecimal), - * which avoids potential precision loss due to rounding. But there exists risk of precision overflow. - * 2. The scale will be zero if all input values are null. - */ - public static HostColumnVector fromDecimals(BigDecimal... values) { - // 1. Fetch the element with max precision (maxDec). Fill with ZERO if inputs is empty. - // 2. Fetch the max scale. Fill with ZERO if inputs is empty. - // 3. Rescale the maxDec with the max scale, so to come out the max precision capacity we need. - BigDecimal maxDec = Arrays.stream(values).filter(Objects::nonNull) - .max(Comparator.comparingInt(BigDecimal::precision)) - .orElse(BigDecimal.ZERO); - int maxScale = Arrays.stream(values) - .map(decimal -> (decimal == null) ? 0 : decimal.scale()) - .max(Comparator.naturalOrder()) - .orElse(0); - maxDec = maxDec.setScale(maxScale, RoundingMode.UNNECESSARY); - - return build(DType.fromJavaBigDecimal(maxDec), values.length, (b) -> b.appendBoxed(values)); - } - /** * Create a new vector from the given values. This API supports inline nulls, * but is much slower than using a regular array and should really only be used @@ -977,8 +903,6 @@ private void appendChildOrNull(ColumnBuilder childBuilder, Object listElement) { childBuilder.append((Byte) listElement); } else if (listElement instanceof Short) { childBuilder.append((Short) listElement); - } else if (listElement instanceof BigDecimal) { - childBuilder.append((BigDecimal) listElement); } else if (listElement instanceof List) { childBuilder.append((List) listElement); } else if (listElement instanceof StructData) { @@ -1068,23 +992,6 @@ public final ColumnBuilder append(boolean value) { return this; } - public final ColumnBuilder append(BigDecimal value) { - growBuffersAndRows(false, currentIndex * type.getSizeInBytes() + type.getSizeInBytes()); - assert currentIndex < rows; - // Rescale input decimal with UNNECESSARY policy, which accepts no precision loss. - BigInteger unscaledVal = value.setScale(-type.getScale(), RoundingMode.UNNECESSARY).unscaledValue(); - if (type.typeId == DType.DTypeEnum.DECIMAL32) { - data.setInt(currentIndex * type.getSizeInBytes(), unscaledVal.intValueExact()); - } else if (type.typeId == DType.DTypeEnum.DECIMAL64) { - data.setLong(currentIndex * type.getSizeInBytes(), unscaledVal.longValueExact()); - } else { - throw new IllegalStateException(type + " is not a supported decimal type."); - } - currentIndex++; - currentByteIndex += type.getSizeInBytes(); - return this; - } - public ColumnBuilder append(String value) { assert value != null : "appendNull must be used to append null strings"; return appendUTF8String(value.getBytes(StandardCharsets.UTF_8)); @@ -1280,57 +1187,6 @@ public final Builder append(double value) { return this; } - /** - * Append java.math.BigDecimal into HostColumnVector with UNNECESSARY RoundingMode. - * Input decimal should have a larger scale than column vector.Otherwise, an ArithmeticException will be thrown while rescaling. - * If unscaledValue after rescaling exceeds the max precision of rapids type, - * an ArithmeticException will be thrown while extracting integral. - * - * @param value BigDecimal value to be appended - */ - public final Builder append(BigDecimal value) { - return append(value, RoundingMode.UNNECESSARY); - } - - /** - * Append java.math.BigDecimal into HostColumnVector with user-defined RoundingMode. - * Input decimal will be rescaled according to scale of column type and RoundingMode before appended. - * If unscaledValue after rescaling exceeds the max precision of rapids type, an ArithmeticException will be thrown. - * - * @param value BigDecimal value to be appended - * @param roundingMode rounding mode determines rescaling behavior - */ - public final Builder append(BigDecimal value, RoundingMode roundingMode) { - assert type.isDecimalType(); - assert currentIndex < rows; - BigInteger unscaledValue = value.setScale(-type.getScale(), roundingMode).unscaledValue(); - if (type.typeId == DType.DTypeEnum.DECIMAL32) { - data.setInt(currentIndex * type.getSizeInBytes(), unscaledValue.intValueExact()); - } else if (type.typeId == DType.DTypeEnum.DECIMAL64) { - data.setLong(currentIndex * type.getSizeInBytes(), unscaledValue.longValueExact()); - } else { - throw new IllegalStateException(type + " is not a supported decimal type."); - } - currentIndex++; - return this; - } - - public final Builder appendUnscaledDecimal(int value) { - assert type.typeId == DType.DTypeEnum.DECIMAL32; - assert currentIndex < rows; - data.setInt(currentIndex * type.getSizeInBytes(), value); - currentIndex++; - return this; - } - - public final Builder appendUnscaledDecimal(long value) { - assert type.typeId == DType.DTypeEnum.DECIMAL64; - assert currentIndex < rows; - data.setLong(currentIndex * type.getSizeInBytes(), value); - currentIndex++; - return this; - } - public Builder append(String value) { assert value != null : "appendNull must be used to append null strings"; return appendUTF8String(value.getBytes(StandardCharsets.UTF_8)); @@ -1427,40 +1283,6 @@ public Builder appendArray(double... values) { return this; } - public Builder appendUnscaledDecimalArray(int... values) { - assert type.typeId == DType.DTypeEnum.DECIMAL32; - assert (values.length + currentIndex) <= rows; - data.setInts(currentIndex * type.getSizeInBytes(), values, 0, values.length); - currentIndex += values.length; - return this; - } - - public Builder appendUnscaledDecimalArray(long... values) { - assert type.typeId == DType.DTypeEnum.DECIMAL64; - assert (values.length + currentIndex) <= rows; - data.setLongs(currentIndex * type.getSizeInBytes(), values, 0, values.length); - currentIndex += values.length; - return this; - } - - /** - * Append multiple values. This is very slow and should really only be used for tests. - * @param values the values to append, including nulls. - * @return this for chaining. - * @throws {@link IndexOutOfBoundsException} - */ - public Builder appendBoxed(BigDecimal... values) throws IndexOutOfBoundsException { - assert type.isDecimalType(); - for (BigDecimal v : values) { - if (v == null) { - appendNull(); - } else { - append(v); - } - } - return this; - } - /** * Append multiple values. This is very slow and should really only be used for tests. * @param values the values to append, including nulls. diff --git a/java/src/main/java/ai/rapids/cudf/HostColumnVectorCore.java b/java/src/main/java/ai/rapids/cudf/HostColumnVectorCore.java index f470095602d..6b4952d8e4a 100644 --- a/java/src/main/java/ai/rapids/cudf/HostColumnVectorCore.java +++ b/java/src/main/java/ai/rapids/cudf/HostColumnVectorCore.java @@ -21,7 +21,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.math.BigDecimal; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; @@ -372,23 +371,6 @@ public final boolean getBoolean(long index) { return offHeap.data.getBoolean(index * type.getSizeInBytes()); } - /** - * Get the BigDecimal value at index. - */ - public final BigDecimal getBigDecimal(long index) { - assert type.isDecimalType() : type + " is not a supported decimal type."; - assertsForGet(index); - if (type.typeId == DType.DTypeEnum.DECIMAL32) { - int unscaledValue = offHeap.data.getInt(index * type.getSizeInBytes()); - return BigDecimal.valueOf(unscaledValue, -type.getScale()); - } else if (type.typeId == DType.DTypeEnum.DECIMAL64) { - long unscaledValue = offHeap.data.getLong(index * type.getSizeInBytes()); - return BigDecimal.valueOf(unscaledValue, -type.getScale()); - } else { - throw new IllegalStateException(type + " is not a supported decimal type."); - } - } - /** * Get the raw UTF8 bytes at index. This API is faster than getJavaString, but still not * ideal because it is copying the data onto the heap. @@ -536,8 +518,6 @@ private Object readValue(int rowIndex) { case INT16: return offHeap.data.getShort(rowOffset); case BOOL8: return offHeap.data.getBoolean(rowOffset); case STRING: return getString(rowIndex); - case DECIMAL32: return BigDecimal.valueOf(offHeap.data.getInt(rowOffset), -type.getScale()); - case DECIMAL64: return BigDecimal.valueOf(offHeap.data.getLong(rowOffset), -type.getScale()); default: throw new UnsupportedOperationException("Do not support " + type); } } diff --git a/java/src/main/java/ai/rapids/cudf/Scalar.java b/java/src/main/java/ai/rapids/cudf/Scalar.java index 0c85a6e24ac..6c9ca6a3282 100644 --- a/java/src/main/java/ai/rapids/cudf/Scalar.java +++ b/java/src/main/java/ai/rapids/cudf/Scalar.java @@ -215,16 +215,6 @@ public static Scalar fromFloat(float value) { return new Scalar(DType.FLOAT32, makeFloat32Scalar(value, true)); } - public static Scalar fromDecimal(int scale, int unscaledValue) { - long handle = makeDecimal32Scalar(unscaledValue, scale, true); - return new Scalar(DType.create(DType.DTypeEnum.DECIMAL32, scale), handle); - } - - public static Scalar fromDecimal(int scale, long unscaledValue) { - long handle = makeDecimal64Scalar(unscaledValue, scale, true); - return new Scalar(DType.create(DType.DTypeEnum.DECIMAL64, scale), handle); - } - public static Scalar fromFloat(Float value) { if (value == null) { return Scalar.fromNull(DType.FLOAT32); @@ -243,7 +233,7 @@ public static Scalar fromDouble(Double value) { return Scalar.fromDouble(value.doubleValue()); } - public static Scalar fromDecimal(BigDecimal value) { + public static Scalar fromBigDecimal(BigDecimal value) { if (value == null) { return Scalar.fromNull(DType.create(DType.DTypeEnum.DECIMAL64, 0)); } diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 572e9a1a868..09df2279f8f 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -21,13 +21,10 @@ import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import java.math.BigDecimal; -import java.math.RoundingMode; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.function.Supplier; -import java.util.stream.Collectors; import java.util.stream.IntStream; import static ai.rapids.cudf.QuantileMethod.HIGHER; @@ -38,7 +35,11 @@ import static ai.rapids.cudf.TableTest.assertColumnsAreEqual; import static ai.rapids.cudf.TableTest.assertTablesAreEqual; import static ai.rapids.cudf.TableTest.assertStructColumnsAreEqual; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assumptions.assumeTrue; public class ColumnVectorTest extends CudfTestBase { @@ -702,9 +703,11 @@ void testSequenceOtherTypes() { @Test void testFromScalarZeroRows() { - // magic number to invoke factory method specialized for decimal types - int mockScale = -8; for (DType.DTypeEnum type : DType.DTypeEnum.values()) { + // Decimal type not supported yet. Update this once it is supported. + if (type == DType.DTypeEnum.DECIMAL32 || type == DType.DTypeEnum.DECIMAL64) { + continue; + } Scalar s = null; try { switch (type) { @@ -741,12 +744,6 @@ void testFromScalarZeroRows() { case FLOAT64: s = Scalar.fromDouble(1.23456789); break; - case DECIMAL32: - s = Scalar.fromDecimal(mockScale, 123456789); - break; - case DECIMAL64: - s = Scalar.fromDecimal(mockScale, 1234567890123456789L); - break; case TIMESTAMP_DAYS: s = Scalar.timestampDaysFromInt(12345); break; @@ -777,11 +774,7 @@ void testFromScalarZeroRows() { } try (ColumnVector c = ColumnVector.fromScalar(s, 0)) { - if (type.isDecimalType()) { - assertEquals(DType.create(type, mockScale), c.getType()); - } else { - assertEquals(DType.create(type), c.getType()); - } + assertEquals(DType.create(type), c.getType()); assertEquals(0, c.getRowCount()); assertEquals(0, c.getNullCount()); } @@ -805,7 +798,7 @@ void testGetNativeView() { void testFromScalar() { final int rowCount = 4; for (DType.DTypeEnum type : DType.DTypeEnum.values()) { - if(type.isDecimalType()) { + if(type == DType.DTypeEnum.DECIMAL32 || type == DType.DTypeEnum.DECIMAL64) { continue; } Scalar s = null; @@ -971,17 +964,11 @@ void testFromScalar() { void testFromScalarNull() { final int rowCount = 4; for (DType.DTypeEnum type : DType.DTypeEnum.values()) { - if (type == DType.DTypeEnum.EMPTY || type == DType.DTypeEnum.LIST || type == DType.DTypeEnum.STRUCT) { + if (type == DType.DTypeEnum.EMPTY || type == DType.DTypeEnum.LIST || type == DType.DTypeEnum.STRUCT + || type == DType.DTypeEnum.DECIMAL32 || type == DType.DTypeEnum.DECIMAL64) { continue; } - DType dType; - if (type.isDecimalType()) { - // magic number to invoke factory method specialized for decimal types - dType = DType.create(type, -8); - } else { - dType = DType.create(type); - } - try (Scalar s = Scalar.fromNull(dType); + try (Scalar s = Scalar.fromNull(DType.create(type)); ColumnVector c = ColumnVector.fromScalar(s, rowCount); HostColumnVector hc = c.copyToHost()) { assertEquals(type, c.getType().typeId); @@ -2974,28 +2961,6 @@ void testListOfListsCvDoubles() { } } - @Test - void testListOfListsCvDecimals() { - List list1 = Arrays.asList(BigDecimal.valueOf(1.1), BigDecimal.valueOf(2.2), BigDecimal.valueOf(3.3)); - List list2 = Arrays.asList(BigDecimal.valueOf(4.4), BigDecimal.valueOf(5.5), BigDecimal.valueOf(6.6)); - List list3 = Arrays.asList(BigDecimal.valueOf(10.1), BigDecimal.valueOf(20.2), BigDecimal.valueOf(30.3)); - List> mainList1 = new ArrayList<>(); - mainList1.add(list1); - mainList1.add(list2); - List> mainList2 = new ArrayList<>(); - mainList2.add(list3); - - HostColumnVector.BasicType basicType = new HostColumnVector.BasicType(true, DType.create(DType.DTypeEnum.DECIMAL32, -1)); - try(ColumnVector res = ColumnVector.fromLists(new HostColumnVector.ListType(true, - new HostColumnVector.ListType(true, basicType)), mainList1, mainList2); - HostColumnVector hcv = res.copyToHost()) { - List> ret1 = hcv.getList(0); - List> ret2 = hcv.getList(1); - assertEquals(mainList1, ret1, "Lists don't match"); - assertEquals(mainList2, ret2, "Lists don't match"); - } - } - @Test void testConcatLists() { List list1 = Arrays.asList(0, 1, 2, 3); @@ -3095,32 +3060,6 @@ void testHcvOfInts() { } } - @Test - void testHcvOfDecimals() { - List[] data = new List[6]; - data[0] = Arrays.asList(BigDecimal.ONE, BigDecimal.TEN); - data[1] = Arrays.asList(BigDecimal.ZERO); - data[2] = null; - data[3] = Arrays.asList(); - data[4] = Arrays.asList(BigDecimal.valueOf(123), BigDecimal.valueOf(1, -2)); - data[5] = Arrays.asList(BigDecimal.valueOf(100, -3), BigDecimal.valueOf(2, -4)); - try(ColumnVector expected = ColumnVector.fromLists( - new HostColumnVector.ListType(true, - new HostColumnVector.BasicType(true, DType.create(DType.DTypeEnum.DECIMAL32, 0))), data); - HostColumnVector hcv = expected.copyToHost()) { - for (int i = 0; i < data.length; i++) { - if (data[i] == null) { - assertNull(hcv.getList(i)); - continue; - } - List exp = data[i].stream() - .map((dec -> (dec == null) ? null : dec.setScale(0, RoundingMode.UNNECESSARY))) - .collect(Collectors.toList()); - assertEquals(exp, hcv.getList(i)); - } - } - } - @Test void testConcatListsOfLists() { List list1 = Arrays.asList(1, 2, 3); diff --git a/java/src/test/java/ai/rapids/cudf/DecimalColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/DecimalColumnVectorTest.java deleted file mode 100644 index 8703981786a..00000000000 --- a/java/src/test/java/ai/rapids/cudf/DecimalColumnVectorTest.java +++ /dev/null @@ -1,332 +0,0 @@ -/* - * Copyright (c) 2020, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package ai.rapids.cudf; - -import ai.rapids.cudf.HostColumnVector.Builder; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; - -import java.math.BigDecimal; -import java.math.RoundingMode; -import java.util.Arrays; -import java.util.Objects; -import java.util.Random; - -import static org.junit.jupiter.api.Assertions.*; - -public class DecimalColumnVectorTest extends CudfTestBase { - private static final Random rdSeed = new Random(1234); - private static final int dec32Scale = 4; - private static final int dec64Scale = 10; - - private static final BigDecimal[] decimal32Zoo = new BigDecimal[20]; - private static final BigDecimal[] decimal64Zoo = new BigDecimal[20]; - private static final int[] unscaledDec32Zoo = new int[decimal32Zoo.length]; - private static final long[] unscaledDec64Zoo = new long[decimal64Zoo.length]; - - private final BigDecimal[] boundaryDecimal32 = new BigDecimal[]{ - new BigDecimal("999999999"), new BigDecimal("-999999999")}; - - private final BigDecimal[] boundaryDecimal64 = new BigDecimal[]{ - new BigDecimal("999999999999999999"), new BigDecimal("-999999999999999999")}; - - private final BigDecimal[] overflowDecimal32 = new BigDecimal[]{ - BigDecimal.valueOf(Integer.MAX_VALUE), BigDecimal.valueOf(Integer.MIN_VALUE)}; - - private final BigDecimal[] overflowDecimal64 = new BigDecimal[]{ - BigDecimal.valueOf(Long.MAX_VALUE), BigDecimal.valueOf(Long.MIN_VALUE)}; - - @BeforeAll - public static void setup() { - for (int i = 0; i < decimal32Zoo.length; i++) { - unscaledDec32Zoo[i] = rdSeed.nextInt() / 100; - unscaledDec64Zoo[i] = rdSeed.nextLong() / 100; - if (rdSeed.nextBoolean()) { - // Create BigDecimal with slight variance on scale, in order to test building cv from inputs with different scales. - decimal32Zoo[i] = BigDecimal.valueOf(rdSeed.nextInt() / 100, dec32Scale - rdSeed.nextInt(2)); - } else { - decimal32Zoo[i] = null; - } - if (rdSeed.nextBoolean()) { - // Create BigDecimal with slight variance on scale, in order to test building cv from inputs with different scales. - decimal64Zoo[i] = BigDecimal.valueOf(rdSeed.nextLong() / 100, dec64Scale - rdSeed.nextInt(2)); - } else { - decimal64Zoo[i] = null; - } - } - } - - @Test - public void testCreateColumnVectorBuilder() { - try (ColumnVector cv = ColumnVector.build(DType.create(DType.DTypeEnum.DECIMAL32, -5), 3, - (b) -> b.append(BigDecimal.valueOf(123456789, 5)))) { - assertFalse(cv.hasNulls()); - } - try (ColumnVector cv = ColumnVector.build(DType.create(DType.DTypeEnum.DECIMAL64, -10), 3, - (b) -> b.append(BigDecimal.valueOf(1023040506070809L, 10)))) { - assertFalse(cv.hasNulls()); - } - // test building ColumnVector from BigDecimal values with varying scales - try (ColumnVector cv = ColumnVector.build(DType.create(DType.DTypeEnum.DECIMAL64, -5), 7, - (b) -> b.append(BigDecimal.valueOf(123456, 0), RoundingMode.UNNECESSARY) - .append(BigDecimal.valueOf(123456, 2), RoundingMode.UNNECESSARY) - .append(BigDecimal.valueOf(123456, 5)) - .append(BigDecimal.valueOf(123456, 7), RoundingMode.HALF_UP) - .append(BigDecimal.valueOf(123456, 7), RoundingMode.FLOOR) - .append(BigDecimal.valueOf(123456, 9), RoundingMode.HALF_DOWN) - .append(BigDecimal.valueOf(123456, 9), RoundingMode.CEILING))) { - try (HostColumnVector hcv = cv.copyToHost()) { - assertEquals(12345600000L, hcv.getLong(0)); - assertEquals(123456000L, hcv.getLong(1)); - assertEquals(123456L, hcv.getLong(2)); - assertEquals(1235L, hcv.getLong(3)); - assertEquals(1234L, hcv.getLong(4)); - assertEquals(12L, hcv.getLong(5)); - assertEquals(13L, hcv.getLong(6)); - } - } - } - - @Test - public void testUpperIndexOutOfBoundsException() { - try (HostColumnVector decColumnVector = HostColumnVector.fromDecimals(decimal32Zoo)) { - assertThrows(AssertionError.class, () -> decColumnVector.getBigDecimal(decimal32Zoo.length)); - } - } - - @Test - public void testLowerIndexOutOfBoundsException() { - try (HostColumnVector doubleColumnVector = HostColumnVector.fromDecimals(decimal32Zoo)) { - assertThrows(AssertionError.class, () -> doubleColumnVector.getBigDecimal(-1)); - } - } - - @Test - public void testAddingNullValues() { - try (HostColumnVector cv = HostColumnVector.fromDecimals(decimal64Zoo)) { - for (int i = 0; i < decimal64Zoo.length; ++i) { - assertEquals(decimal64Zoo[i] == null, cv.isNull(i)); - } - assertEquals(Arrays.stream(decimal64Zoo).filter(Objects::isNull).count(), cv.getNullCount()); - } - } - - @Test - public void testOverrunningTheBuffer() { - try (Builder builder = HostColumnVector.builder(DType.create(DType.DTypeEnum.DECIMAL32, -dec32Scale), 3)) { - assertThrows(AssertionError.class, () -> builder.appendBoxed(decimal32Zoo).build()); - } - try (Builder builder = HostColumnVector.builder(DType.create(DType.DTypeEnum.DECIMAL64, -dec64Scale), 3)) { - assertThrows(AssertionError.class, () -> builder.appendUnscaledDecimalArray(unscaledDec64Zoo).build()); - } - } - - @Test - public void testDecimalValidation() { - // precision overflow - assertThrows(IllegalArgumentException.class, () -> HostColumnVector.fromDecimals(overflowDecimal64)); - assertThrows(IllegalArgumentException.class, () -> { - ColumnVector.decimalFromInts(-(DType.DECIMAL32_MAX_PRECISION + 1), unscaledDec32Zoo); - }); - assertThrows(IllegalArgumentException.class, () -> { - ColumnVector.decimalFromLongs(-(DType.DECIMAL64_MAX_PRECISION + 1), unscaledDec64Zoo); - }); - // precision overflow due to rescaling by min scale - assertThrows(IllegalArgumentException.class, () -> { - ColumnVector.fromDecimals(BigDecimal.valueOf(1.23e10), BigDecimal.valueOf(1.2e-7)); - }); - // exactly hit the MAX_PRECISION_DECIMAL64 after rescaling - assertDoesNotThrow(() -> { - ColumnVector.fromDecimals(BigDecimal.valueOf(1.23e10), BigDecimal.valueOf(1.2e-6)); - }); - } - - @Test - public void testDecimalGeneral() { - // Safe max precision of Decimal32 is 9, so integers have 10 digits will be backed by DECIMAL64. - try (ColumnVector cv = ColumnVector.fromDecimals(overflowDecimal32)) { - assertEquals(DType.create(DType.DTypeEnum.DECIMAL64, 0), cv.getDataType()); - } - // Create DECIMAL64 vector with small values - try (ColumnVector cv = ColumnVector.decimalFromLongs(0, 0L)) { - try (HostColumnVector hcv = cv.copyToHost()) { - assertTrue(hcv.getType().isBackedByLong()); - assertEquals(0L, hcv.getBigDecimal(0).longValue()); - } - } - } - - @Test - public void testDecimalFromDecimals() { - DecimalColumnVectorTest.testDecimalImpl(false, dec32Scale, decimal32Zoo); - DecimalColumnVectorTest.testDecimalImpl(true, dec64Scale, decimal64Zoo); - DecimalColumnVectorTest.testDecimalImpl(false, 0, boundaryDecimal32); - DecimalColumnVectorTest.testDecimalImpl(true, 0, boundaryDecimal64); - } - - private static void testDecimalImpl(boolean isInt64, int scale, BigDecimal[] decimalZoo) { - try (ColumnVector cv = ColumnVector.fromDecimals(decimalZoo)) { - try (HostColumnVector hcv = cv.copyToHost()) { - assertEquals(-scale, hcv.getType().getScale()); - assertEquals(isInt64, hcv.getType().typeId == DType.DTypeEnum.DECIMAL64); - assertEquals(decimalZoo.length, hcv.rows); - for (int i = 0; i < decimalZoo.length; i++) { - assertEquals(decimalZoo[i] == null, hcv.isNull(i)); - if (decimalZoo[i] != null) { - assertEquals(decimalZoo[i].floatValue(), hcv.getBigDecimal(i).floatValue()); - long backValue = isInt64 ? hcv.getLong(i) : hcv.getInt(i); - assertEquals(decimalZoo[i].setScale(scale, RoundingMode.UNNECESSARY), BigDecimal.valueOf(backValue, scale)); - } - } - } - } - } - - @Test - private void testDecimalFromInts() { - try (ColumnVector cv = ColumnVector.decimalFromInts(-DecimalColumnVectorTest.dec32Scale, DecimalColumnVectorTest.unscaledDec32Zoo)) { - try (HostColumnVector hcv = cv.copyToHost()) { - for (int i = 0; i < DecimalColumnVectorTest.unscaledDec32Zoo.length; i++) { - assertEquals(DecimalColumnVectorTest.unscaledDec32Zoo[i], hcv.getInt(i)); - assertEquals(BigDecimal.valueOf(DecimalColumnVectorTest.unscaledDec32Zoo[i], DecimalColumnVectorTest.dec32Scale), hcv.getBigDecimal(i)); - } - } - } - } - - @Test - private static void testDecimalFromLongs() { - try (ColumnVector cv = ColumnVector.decimalFromLongs(-DecimalColumnVectorTest.dec64Scale, DecimalColumnVectorTest.unscaledDec64Zoo)) { - try (HostColumnVector hcv = cv.copyToHost()) { - for (int i = 0; i < DecimalColumnVectorTest.unscaledDec64Zoo.length; i++) { - assertEquals(DecimalColumnVectorTest.unscaledDec64Zoo[i], hcv.getLong(i)); - assertEquals(BigDecimal.valueOf(DecimalColumnVectorTest.unscaledDec64Zoo[i], DecimalColumnVectorTest.dec64Scale), hcv.getBigDecimal(i)); - } - } - } - } - - @Test - public void testDecimalFromDoubles() { - DType dt = DType.create(DType.DTypeEnum.DECIMAL32, -3); - try (ColumnVector cv = ColumnVector.decimalFromDoubles(dt, RoundingMode.DOWN,123456, -2.4567, 3.00001, -1111e-5)) { - try (HostColumnVector hcv = cv.copyToHost()) { - assertEquals(123456, hcv.getBigDecimal(0).doubleValue()); - assertEquals(-2.456, hcv.getBigDecimal(1).doubleValue()); - assertEquals(3, hcv.getBigDecimal(2).doubleValue()); - assertEquals(-0.011, hcv.getBigDecimal(3).doubleValue()); - } - } - dt = DType.create(DType.DTypeEnum.DECIMAL64, -10); - try (ColumnVector cv = ColumnVector.decimalFromDoubles(dt, RoundingMode.HALF_UP, 1.2345678, -2.45e-9, 3.000012, -51111e-15)) { - try (HostColumnVector hcv = cv.copyToHost()) { - assertEquals(1.2345678, hcv.getBigDecimal(0).doubleValue()); - assertEquals(-2.5e-9, hcv.getBigDecimal(1).doubleValue()); - assertEquals(3.000012, hcv.getBigDecimal(2).doubleValue()); - assertEquals(-1e-10, hcv.getBigDecimal(3).doubleValue()); - } - } - dt = DType.create(DType.DTypeEnum.DECIMAL64, 10); - try (ColumnVector cv = ColumnVector.decimalFromDoubles(dt, RoundingMode.UP, 1.234e20, -12.34e8, 1.1e10)) { - try (HostColumnVector hcv = cv.copyToHost()) { - assertEquals(1.234e20, hcv.getBigDecimal(0).doubleValue()); - assertEquals(-1e10, hcv.getBigDecimal(1).doubleValue()); - assertEquals(2e10, hcv.getBigDecimal(2).doubleValue()); - } - } - assertThrows(ArithmeticException.class, - () -> { - final DType dt1 = DType.create(DType.DTypeEnum.DECIMAL32, -5); - try (ColumnVector cv = ColumnVector.decimalFromDoubles(dt1, RoundingMode.UNNECESSARY, 30000)) { - } - }); - assertThrows(ArithmeticException.class, - () -> { - final DType dt1 = DType.create(DType.DTypeEnum.DECIMAL64, 10); - try (ColumnVector cv = ColumnVector.decimalFromDoubles(dt1, RoundingMode.FLOOR, 1e100)) { - } - }); - } - - @Test - public void testAppendVector() { - for (DType decType : new DType[]{ - DType.create(DType.DTypeEnum.DECIMAL32, -6), - DType.create(DType.DTypeEnum.DECIMAL64, -10)}) { - for (int dstSize = 1; dstSize <= 100; dstSize++) { - for (int dstPrefilledSize = 0; dstPrefilledSize < dstSize; dstPrefilledSize++) { - final int srcSize = dstSize - dstPrefilledSize; - for (int sizeOfDataNotToAdd = 0; sizeOfDataNotToAdd <= dstPrefilledSize; sizeOfDataNotToAdd++) { - try (Builder dst = HostColumnVector.builder(decType, dstSize); - HostColumnVector src = HostColumnVector.build(decType, srcSize, (b) -> { - for (int i = 0; i < srcSize; i++) { - if (rdSeed.nextBoolean()) { - b.appendNull(); - } else { - b.append(BigDecimal.valueOf(rdSeed.nextInt() / 100, -decType.getScale())); - } - } - }); - Builder gtBuilder = HostColumnVector.builder(decType, dstPrefilledSize)) { - assertEquals(dstSize, srcSize + dstPrefilledSize); - //add the first half of the prefilled list - for (int i = 0; i < dstPrefilledSize - sizeOfDataNotToAdd; i++) { - if (rdSeed.nextBoolean()) { - dst.appendNull(); - gtBuilder.appendNull(); - } else { - BigDecimal a = BigDecimal.valueOf(rdSeed.nextInt() / 100, -decType.getScale()); - if (decType.typeId == DType.DTypeEnum.DECIMAL32) { - dst.appendUnscaledDecimal(a.unscaledValue().intValueExact()); - } else { - dst.appendUnscaledDecimal(a.unscaledValue().longValueExact()); - } - gtBuilder.append(a); - } - } - // append the src vector - dst.append(src); - try (HostColumnVector dstVector = dst.build(); - HostColumnVector gt = gtBuilder.build()) { - for (int i = 0; i < dstPrefilledSize - sizeOfDataNotToAdd; i++) { - assertEquals(gt.isNull(i), dstVector.isNull(i)); - if (!gt.isNull(i)) { - assertEquals(gt.getBigDecimal(i), dstVector.getBigDecimal(i)); - } - } - for (int i = dstPrefilledSize - sizeOfDataNotToAdd, j = 0; i < dstSize - sizeOfDataNotToAdd && j < srcSize; i++, j++) { - assertEquals(src.isNull(j), dstVector.isNull(i)); - if (!src.isNull(j)) { - assertEquals(src.getBigDecimal(j), dstVector.getBigDecimal(i)); - } - } - if (dstVector.hasValidityVector()) { - long maxIndex = - BitVectorHelper.getValidityAllocationSizeInBytes(dstVector.getRowCount()) * 8; - for (long i = dstSize - sizeOfDataNotToAdd; i < maxIndex; i++) { - assertFalse(dstVector.isNullExtendedRange(i)); - } - } - } - } - } - } - } - } - } -} diff --git a/java/src/test/java/ai/rapids/cudf/ScalarTest.java b/java/src/test/java/ai/rapids/cudf/ScalarTest.java index 627171e4b2f..47cad78ce5c 100644 --- a/java/src/test/java/ai/rapids/cudf/ScalarTest.java +++ b/java/src/test/java/ai/rapids/cudf/ScalarTest.java @@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test; import java.math.BigDecimal; +import java.math.MathContext; import static org.junit.jupiter.api.Assertions.*; @@ -49,7 +50,7 @@ public void testIncRef() { public void testNull() { for (DType.DTypeEnum dataType : DType.DTypeEnum.values()) { DType type; - if (dataType.isDecimalType()) { + if (dataType == DType.DTypeEnum.DECIMAL32 || dataType == DType.DTypeEnum.DECIMAL64) { type = DType.create(dataType, -3); } else { type = DType.create(dataType); @@ -133,20 +134,12 @@ public void testDecimal() { BigDecimal.valueOf(12345678, 2), BigDecimal.valueOf(1234567890123L, 6), }; - for (BigDecimal dec: bigDecimals) { - try (Scalar s = Scalar.fromDecimal(dec)) { - assertEquals(DType.fromJavaBigDecimal(dec), s.getType()); + for (BigDecimal bigDec: bigDecimals) { + try (Scalar s = Scalar.fromBigDecimal(bigDec)) { + assertEquals(DType.fromJavaBigDecimal(bigDec), s.getType()); assertTrue(s.isValid()); - assertEquals(dec.unscaledValue().longValueExact(), s.getLong()); - assertEquals(dec, s.getBigDecimal()); - } - try (Scalar s = Scalar.fromDecimal(-dec.scale(), dec.unscaledValue().intValueExact())) { - assertEquals(dec, s.getBigDecimal()); - } catch (java.lang.ArithmeticException ex) { - try (Scalar s = Scalar.fromDecimal(-dec.scale(), dec.unscaledValue().longValueExact())) { - assertEquals(dec, s.getBigDecimal()); - assertTrue(s.getType().isBackedByLong()); - } + assertEquals(bigDec.unscaledValue().longValueExact(), s.getLong()); + assertEquals(bigDec, s.getBigDecimal()); } } }