diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index 9f414661967..defb6eea5b9 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -1164,6 +1164,17 @@ public static ColumnVector decimalFromInts(int scale, int... values) { } } + /** + * Create a new decimal vector from boxed unscaled values (Integer array) and scale. + * The created vector is of type DType.DECIMAL32, whose max precision is 9. + * Compared with scale of [[java.math.BigDecimal]], the scale here represents the opposite meaning. + */ + public static ColumnVector decimalFromBoxedInts(int scale, Integer... values) { + try (HostColumnVector host = HostColumnVector.decimalFromBoxedInts(scale, values)) { + return host.copyToDevice(); + } + } + /** * Create a new decimal vector from unscaled values (long array) and scale. * The created vector is of type DType.DECIMAL64, whose max precision is 18. @@ -1175,6 +1186,17 @@ public static ColumnVector decimalFromLongs(int scale, long... values) { } } + /** + * Create a new decimal vector from boxed unscaled values (Long array) and scale. + * The created vector is of type DType.DECIMAL64, whose max precision is 18. + * Compared with scale of [[java.math.BigDecimal]], the scale here represents the opposite meaning. + */ + public static ColumnVector decimalFromBoxedLongs(int scale, Long... values) { + try (HostColumnVector host = HostColumnVector.decimalFromBoxedLongs(scale, values)) { + return host.copyToDevice(); + } + } + /** * Create a new decimal vector from double floats with specific DecimalType and RoundingMode. * All doubles will be rescaled if necessary, according to scale of input DecimalType and RoundingMode. diff --git a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java index 559256aa7bf..846bcb3b635 100644 --- a/java/src/main/java/ai/rapids/cudf/HostColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/HostColumnVector.java @@ -481,6 +481,23 @@ public static HostColumnVector decimalFromInts(int scale, int... values) { return build(DType.create(DType.DTypeEnum.DECIMAL32, scale), values.length, (b) -> b.appendUnscaledDecimalArray(values)); } + /** + * Create a new decimal vector from boxed unscaled values (Integer array) and scale. + * The created vector is of type DType.DECIMAL32, whose max precision is 9. + * Compared with scale of [[java.math.BigDecimal]], the scale here represents the opposite meaning. + */ + public static HostColumnVector decimalFromBoxedInts(int scale, Integer... values) { + return build(DType.create(DType.DTypeEnum.DECIMAL32, scale), values.length, (b) -> { + for (Integer v : values) { + if (v == null) { + b.appendNull(); + } else { + b.appendUnscaledDecimal(v); + } + } + }); + } + /** * Create a new decimal vector from unscaled values (long array) and scale. * The created vector is of type DType.DECIMAL64, whose max precision is 18. @@ -490,6 +507,23 @@ public static HostColumnVector decimalFromLongs(int scale, long... values) { return build(DType.create(DType.DTypeEnum.DECIMAL64, scale), values.length, (b) -> b.appendUnscaledDecimalArray(values)); } + /** + * Create a new decimal vector from boxed unscaled values (Long array) and scale. + * The created vector is of type DType.DECIMAL64, whose max precision is 18. + * Compared with scale of [[java.math.BigDecimal]], the scale here represents the opposite meaning. + */ + public static HostColumnVector decimalFromBoxedLongs(int scale, Long... values) { + return build(DType.create(DType.DTypeEnum.DECIMAL64, scale), values.length, (b) -> { + for (Long v : values) { + if (v == null) { + b.appendNull(); + } else { + b.appendUnscaledDecimal(v); + } + } + }); + } + /** * Create a new decimal vector from double floats with specific DecimalType and RoundingMode. * All doubles will be rescaled if necessary, according to scale of input DecimalType and RoundingMode. diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index 73db5ee4df3..4132016d85c 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -42,6 +42,7 @@ #include #include #include +#include #include #include #include @@ -712,6 +713,10 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_castTo(JNIEnv *env, jclas case cudf::type_id::UINT64: result = cudf::strings::from_integers(*column); break; + case cudf::type_id::DECIMAL32: + case cudf::type_id::DECIMAL64: + result = cudf::strings::from_fixed_point(*column); + break; default: JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Invalid data type", 0); } } else if (column->type().id() == cudf::type_id::STRING) { @@ -733,6 +738,10 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_castTo(JNIEnv *env, jclas case cudf::type_id::UINT64: result = cudf::strings::to_integers(*column, n_data_type); break; + case cudf::type_id::DECIMAL32: + case cudf::type_id::DECIMAL64: + result = cudf::strings::to_fixed_point(*column, n_data_type); + break; default: JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Invalid data type", 0); } } else if (cudf::is_timestamp(n_data_type) && cudf::is_numeric(column->type())) { diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 420e176efe2..00d6e51fd91 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -2226,6 +2226,73 @@ void testCastBoolToString() { testCastFixedWidthToStringsAndBack(DType.BOOL8, () -> ColumnVector.fromBoxedBooleans(booleans), () -> ColumnVector.fromStrings(stringBools)); } + @Test + void testCastDecimal32ToString() { + + Integer[] unScaledValues = {0, null, 3, 2, -43, null, 5234, -73451, 348093, -234810}; + String[] strDecimalValues = new String[unScaledValues.length]; + for (int scale : new int[]{-2, -1, 0, 1, 2}) { + for (int i = 0; i < strDecimalValues.length; i++) { + Long value = unScaledValues[i] == null ? null : Long.valueOf(unScaledValues[i]); + strDecimalValues[i] = dumpDecimal(value, scale); + } + + testCastFixedWidthToStringsAndBack(DType.create(DType.DTypeEnum.DECIMAL32, scale), + () -> ColumnVector.decimalFromBoxedInts(scale, unScaledValues), + () -> ColumnVector.fromStrings(strDecimalValues)); + } + } + + @Test + void testCastDecimal64ToString() { + + Long[] unScaledValues = {0l, null, 3l, 2l, -43l, null, 234802l, -94582l, 1234208124l, -2342348023812l}; + String[] strDecimalValues = new String[unScaledValues.length]; + for (int scale : new int[]{-5, -2, -1, 0, 1, 2, 5}) { + for (int i = 0; i < strDecimalValues.length; i++) { + strDecimalValues[i] = dumpDecimal(unScaledValues[i], scale); + System.out.println(strDecimalValues[i]); + } + + testCastFixedWidthToStringsAndBack(DType.create(DType.DTypeEnum.DECIMAL64, scale), + () -> ColumnVector.decimalFromBoxedLongs(scale, unScaledValues), + () -> ColumnVector.fromStrings(strDecimalValues)); + } + } + + /** + * Helper function to create decimal strings which can be processed by castStringToDecimal functor. + * We can not simply create decimal string via `String.valueOf`, because castStringToDecimal doesn't + * support scientific notations so far. + * + * issue for scientific notation: https://github.com/rapidsai/cudf/issues/7665 + */ + private static String dumpDecimal(Long unscaledValue, int scale) { + if (unscaledValue == null) return null; + + StringBuilder builder = new StringBuilder(); + if (unscaledValue < 0) builder.append('-'); + String absValue = String.valueOf(Math.abs(unscaledValue)); + + if (scale >= 0) { + builder.append(absValue); + for (int i = 0; i < scale; i++) builder.append('0'); + return builder.toString(); + } + + if (absValue.length() <= -scale) { + builder.append('0').append('.'); + for (int i = 0; i < -scale - absValue.length(); i++) builder.append('0'); + builder.append(absValue); + } else { + int split = absValue.length() + scale; + builder.append(absValue.substring(0, split)) + .append('.') + .append(absValue.substring(split)); + } + return builder.toString(); + } + private static String[] getStringArray(T[] input) { String[] result = new String[input.length]; for (int i = 0 ; i < input.length ; i++) {