Skip to content

Commit

Permalink
Java support of casting string from/to decimal (#7623)
Browse files Browse the repository at this point in the history
This pull request provided Java side support of casting string from/to decimal, which is required by spark-rapids.

Although parsing of string formatted as scientific notation to decimal has not been supported yet by cuDF, we are able to implement string to decimal conversion at spark-rapids side through a two-steps hack: 
1. casting string to float 
2. casting float to decimal

In addition, this pull request also addressed issue #6795.

Authors:
  - Alfred Xu (@sperlingxx)

Approvers:
  - Robert (Bobby) Evans (@revans2)

URL: #7623
  • Loading branch information
sperlingxx authored Mar 23, 2021
1 parent 8632ca0 commit 4e9241e
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 0 deletions.
22 changes: 22 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -1164,6 +1164,17 @@ public static ColumnVector decimalFromInts(int scale, int... values) {
}
}

/**
* Create a new decimal vector from boxed unscaled values (Integer array) and scale.
* The created vector is of type DType.DECIMAL32, whose max precision is 9.
* Compared with scale of [[java.math.BigDecimal]], the scale here represents the opposite meaning.
*/
public static ColumnVector decimalFromBoxedInts(int scale, Integer... values) {
try (HostColumnVector host = HostColumnVector.decimalFromBoxedInts(scale, values)) {
return host.copyToDevice();
}
}

/**
* Create a new decimal vector from unscaled values (long array) and scale.
* The created vector is of type DType.DECIMAL64, whose max precision is 18.
Expand All @@ -1175,6 +1186,17 @@ public static ColumnVector decimalFromLongs(int scale, long... values) {
}
}

/**
* Create a new decimal vector from boxed unscaled values (Long array) and scale.
* The created vector is of type DType.DECIMAL64, whose max precision is 18.
* Compared with scale of [[java.math.BigDecimal]], the scale here represents the opposite meaning.
*/
public static ColumnVector decimalFromBoxedLongs(int scale, Long... values) {
try (HostColumnVector host = HostColumnVector.decimalFromBoxedLongs(scale, values)) {
return host.copyToDevice();
}
}

/**
* Create a new decimal vector from double floats with specific DecimalType and RoundingMode.
* All doubles will be rescaled if necessary, according to scale of input DecimalType and RoundingMode.
Expand Down
34 changes: 34 additions & 0 deletions java/src/main/java/ai/rapids/cudf/HostColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,23 @@ public static HostColumnVector decimalFromInts(int scale, int... values) {
return build(DType.create(DType.DTypeEnum.DECIMAL32, scale), values.length, (b) -> b.appendUnscaledDecimalArray(values));
}

/**
* Create a new decimal vector from boxed unscaled values (Integer array) and scale.
* The created vector is of type DType.DECIMAL32, whose max precision is 9.
* Compared with scale of [[java.math.BigDecimal]], the scale here represents the opposite meaning.
*/
public static HostColumnVector decimalFromBoxedInts(int scale, Integer... values) {
return build(DType.create(DType.DTypeEnum.DECIMAL32, scale), values.length, (b) -> {
for (Integer v : values) {
if (v == null) {
b.appendNull();
} else {
b.appendUnscaledDecimal(v);
}
}
});
}

/**
* Create a new decimal vector from unscaled values (long array) and scale.
* The created vector is of type DType.DECIMAL64, whose max precision is 18.
Expand All @@ -490,6 +507,23 @@ public static HostColumnVector decimalFromLongs(int scale, long... values) {
return build(DType.create(DType.DTypeEnum.DECIMAL64, scale), values.length, (b) -> b.appendUnscaledDecimalArray(values));
}

/**
* Create a new decimal vector from boxed unscaled values (Long array) and scale.
* The created vector is of type DType.DECIMAL64, whose max precision is 18.
* Compared with scale of [[java.math.BigDecimal]], the scale here represents the opposite meaning.
*/
public static HostColumnVector decimalFromBoxedLongs(int scale, Long... values) {
return build(DType.create(DType.DTypeEnum.DECIMAL64, scale), values.length, (b) -> {
for (Long v : values) {
if (v == null) {
b.appendNull();
} else {
b.appendUnscaledDecimal(v);
}
}
});
}

/**
* Create a new decimal vector from double floats with specific DecimalType and RoundingMode.
* All doubles will be rescaled if necessary, according to scale of input DecimalType and RoundingMode.
Expand Down
9 changes: 9 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include <cudf/strings/contains.hpp>
#include <cudf/strings/convert/convert_booleans.hpp>
#include <cudf/strings/convert/convert_datetime.hpp>
#include <cudf/strings/convert/convert_fixed_point.hpp>
#include <cudf/strings/convert/convert_floats.hpp>
#include <cudf/strings/convert/convert_integers.hpp>
#include <cudf/strings/convert/convert_urls.hpp>
Expand Down Expand Up @@ -712,6 +713,10 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_castTo(JNIEnv *env, jclas
case cudf::type_id::UINT64:
result = cudf::strings::from_integers(*column);
break;
case cudf::type_id::DECIMAL32:
case cudf::type_id::DECIMAL64:
result = cudf::strings::from_fixed_point(*column);
break;
default: JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Invalid data type", 0);
}
} else if (column->type().id() == cudf::type_id::STRING) {
Expand All @@ -733,6 +738,10 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_castTo(JNIEnv *env, jclas
case cudf::type_id::UINT64:
result = cudf::strings::to_integers(*column, n_data_type);
break;
case cudf::type_id::DECIMAL32:
case cudf::type_id::DECIMAL64:
result = cudf::strings::to_fixed_point(*column, n_data_type);
break;
default: JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Invalid data type", 0);
}
} else if (cudf::is_timestamp(n_data_type) && cudf::is_numeric(column->type())) {
Expand Down
67 changes: 67 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2226,6 +2226,73 @@ void testCastBoolToString() {
testCastFixedWidthToStringsAndBack(DType.BOOL8, () -> ColumnVector.fromBoxedBooleans(booleans), () -> ColumnVector.fromStrings(stringBools));
}

@Test
void testCastDecimal32ToString() {

Integer[] unScaledValues = {0, null, 3, 2, -43, null, 5234, -73451, 348093, -234810};
String[] strDecimalValues = new String[unScaledValues.length];
for (int scale : new int[]{-2, -1, 0, 1, 2}) {
for (int i = 0; i < strDecimalValues.length; i++) {
Long value = unScaledValues[i] == null ? null : Long.valueOf(unScaledValues[i]);
strDecimalValues[i] = dumpDecimal(value, scale);
}

testCastFixedWidthToStringsAndBack(DType.create(DType.DTypeEnum.DECIMAL32, scale),
() -> ColumnVector.decimalFromBoxedInts(scale, unScaledValues),
() -> ColumnVector.fromStrings(strDecimalValues));
}
}

@Test
void testCastDecimal64ToString() {

Long[] unScaledValues = {0l, null, 3l, 2l, -43l, null, 234802l, -94582l, 1234208124l, -2342348023812l};
String[] strDecimalValues = new String[unScaledValues.length];
for (int scale : new int[]{-5, -2, -1, 0, 1, 2, 5}) {
for (int i = 0; i < strDecimalValues.length; i++) {
strDecimalValues[i] = dumpDecimal(unScaledValues[i], scale);
System.out.println(strDecimalValues[i]);
}

testCastFixedWidthToStringsAndBack(DType.create(DType.DTypeEnum.DECIMAL64, scale),
() -> ColumnVector.decimalFromBoxedLongs(scale, unScaledValues),
() -> ColumnVector.fromStrings(strDecimalValues));
}
}

/**
* Helper function to create decimal strings which can be processed by castStringToDecimal functor.
* We can not simply create decimal string via `String.valueOf`, because castStringToDecimal doesn't
* support scientific notations so far.
*
* issue for scientific notation: https://github.com/rapidsai/cudf/issues/7665
*/
private static String dumpDecimal(Long unscaledValue, int scale) {
if (unscaledValue == null) return null;

StringBuilder builder = new StringBuilder();
if (unscaledValue < 0) builder.append('-');
String absValue = String.valueOf(Math.abs(unscaledValue));

if (scale >= 0) {
builder.append(absValue);
for (int i = 0; i < scale; i++) builder.append('0');
return builder.toString();
}

if (absValue.length() <= -scale) {
builder.append('0').append('.');
for (int i = 0; i < -scale - absValue.length(); i++) builder.append('0');
builder.append(absValue);
} else {
int split = absValue.length() + scale;
builder.append(absValue.substring(0, split))
.append('.')
.append(absValue.substring(split));
}
return builder.toString();
}

private static <T> String[] getStringArray(T[] input) {
String[] result = new String[input.length];
for (int i = 0 ; i < input.length ; i++) {
Expand Down

0 comments on commit 4e9241e

Please sign in to comment.