Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Java support of casting string from/to decimal [skip ci] #7623

Merged
merged 2 commits into from
Mar 23, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -1164,6 +1164,17 @@ public static ColumnVector decimalFromInts(int scale, int... values) {
}
}

/**
* Create a new decimal vector from boxed unscaled values (Integer array) and scale.
* The created vector is of type DType.DECIMAL32, whose max precision is 9.
* Compared with scale of [[java.math.BigDecimal]], the scale here represents the opposite meaning.
*/
public static ColumnVector decimalFromBoxedInts(int scale, Integer... values) {
try (HostColumnVector host = HostColumnVector.decimalFromBoxedInts(scale, values)) {
return host.copyToDevice();
}
}

/**
* Create a new decimal vector from unscaled values (long array) and scale.
* The created vector is of type DType.DECIMAL64, whose max precision is 18.
Expand All @@ -1175,6 +1186,17 @@ public static ColumnVector decimalFromLongs(int scale, long... values) {
}
}

/**
* Create a new decimal vector from boxed unscaled values (Long array) and scale.
* The created vector is of type DType.DECIMAL64, whose max precision is 18.
* Compared with scale of [[java.math.BigDecimal]], the scale here represents the opposite meaning.
*/
public static ColumnVector decimalFromBoxedLongs(int scale, Long... values) {
try (HostColumnVector host = HostColumnVector.decimalFromBoxedLongs(scale, values)) {
return host.copyToDevice();
}
}

/**
* Create a new decimal vector from double floats with specific DecimalType and RoundingMode.
* All doubles will be rescaled if necessary, according to scale of input DecimalType and RoundingMode.
Expand Down
34 changes: 34 additions & 0 deletions java/src/main/java/ai/rapids/cudf/HostColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,23 @@ public static HostColumnVector decimalFromInts(int scale, int... values) {
return build(DType.create(DType.DTypeEnum.DECIMAL32, scale), values.length, (b) -> b.appendUnscaledDecimalArray(values));
}

/**
* Create a new decimal vector from boxed unscaled values (Integer array) and scale.
* The created vector is of type DType.DECIMAL32, whose max precision is 9.
* Compared with scale of [[java.math.BigDecimal]], the scale here represents the opposite meaning.
*/
public static HostColumnVector decimalFromBoxedInts(int scale, Integer... values) {
return build(DType.create(DType.DTypeEnum.DECIMAL32, scale), values.length, (b) -> {
for (Integer v : values) {
if (v == null) {
b.appendNull();
} else {
b.appendUnscaledDecimal(v);
}
}
});
}

/**
* Create a new decimal vector from unscaled values (long array) and scale.
* The created vector is of type DType.DECIMAL64, whose max precision is 18.
Expand All @@ -490,6 +507,23 @@ public static HostColumnVector decimalFromLongs(int scale, long... values) {
return build(DType.create(DType.DTypeEnum.DECIMAL64, scale), values.length, (b) -> b.appendUnscaledDecimalArray(values));
}

/**
* Create a new decimal vector from boxed unscaled values (Long array) and scale.
* The created vector is of type DType.DECIMAL64, whose max precision is 18.
* Compared with scale of [[java.math.BigDecimal]], the scale here represents the opposite meaning.
*/
public static HostColumnVector decimalFromBoxedLongs(int scale, Long... values) {
return build(DType.create(DType.DTypeEnum.DECIMAL64, scale), values.length, (b) -> {
for (Long v : values) {
if (v == null) {
b.appendNull();
} else {
b.appendUnscaledDecimal(v);
}
}
});
}

/**
* Create a new decimal vector from double floats with specific DecimalType and RoundingMode.
* All doubles will be rescaled if necessary, according to scale of input DecimalType and RoundingMode.
Expand Down
9 changes: 9 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include <cudf/strings/contains.hpp>
#include <cudf/strings/convert/convert_booleans.hpp>
#include <cudf/strings/convert/convert_datetime.hpp>
#include <cudf/strings/convert/convert_fixed_point.hpp>
#include <cudf/strings/convert/convert_floats.hpp>
#include <cudf/strings/convert/convert_integers.hpp>
#include <cudf/strings/convert/convert_urls.hpp>
Expand Down Expand Up @@ -698,6 +699,10 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_castTo(JNIEnv *env, jclas
case cudf::type_id::UINT64:
result = cudf::strings::from_integers(*column);
break;
case cudf::type_id::DECIMAL32:
case cudf::type_id::DECIMAL64:
result = cudf::strings::from_fixed_point(*column);
break;
default: JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Invalid data type", 0);
}
} else if (column->type().id() == cudf::type_id::STRING) {
Expand All @@ -719,6 +724,10 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_castTo(JNIEnv *env, jclas
case cudf::type_id::UINT64:
result = cudf::strings::to_integers(*column, n_data_type);
break;
case cudf::type_id::DECIMAL32:
case cudf::type_id::DECIMAL64:
result = cudf::strings::to_fixed_point(*column, n_data_type);
break;
default: JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Invalid data type", 0);
}
} else if (cudf::is_timestamp(n_data_type) && cudf::is_numeric(column->type())) {
Expand Down
65 changes: 65 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2214,6 +2214,71 @@ void testCastBoolToString() {
testCastFixedWidthToStringsAndBack(DType.BOOL8, () -> ColumnVector.fromBoxedBooleans(booleans), () -> ColumnVector.fromStrings(stringBools));
}

@Test
void testCastDecimal32ToString() {

Integer[] unScaledValues = {0, null, 3, 2, -43, null, 5234, -73451, 348093, -234810};
String[] strDecimalValues = new String[unScaledValues.length];
for (int scale : new int[]{-2, -1, 0, 1, 2}) {
for (int i = 0; i < strDecimalValues.length; i++) {
Long value = unScaledValues[i] == null ? null : Long.valueOf(unScaledValues[i]);
strDecimalValues[i] = dumpDecimal(value, scale);
}

testCastFixedWidthToStringsAndBack(DType.create(DType.DTypeEnum.DECIMAL32, scale),
() -> ColumnVector.decimalFromBoxedInts(scale, unScaledValues),
() -> ColumnVector.fromStrings(strDecimalValues));
}
}

@Test
void testCastDecimal64ToString() {

Long[] unScaledValues = {0l, null, 3l, 2l, -43l, null, 234802l, -94582l, 1234208124l, -2342348023812l};
String[] strDecimalValues = new String[unScaledValues.length];
for (int scale : new int[]{-5, -2, -1, 0, 1, 2, 5}) {
for (int i = 0; i < strDecimalValues.length; i++) {
strDecimalValues[i] = dumpDecimal(unScaledValues[i], scale);
System.out.println(strDecimalValues[i]);
}

testCastFixedWidthToStringsAndBack(DType.create(DType.DTypeEnum.DECIMAL64, scale),
() -> ColumnVector.decimalFromBoxedLongs(scale, unScaledValues),
() -> ColumnVector.fromStrings(strDecimalValues));
}
}

/**
* Helper function to create decimal strings which can be processed by castStringToDecimal functor.
* We can not simply create decimal string via `String.valueOf`, because castStringToDecimal doesn't
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please file a follow on issue for this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've filed issue for scientific notation. And appended the issue link here.

* support scientific notations so far.
*/
private static String dumpDecimal(Long unscaledValue, int scale) {
if (unscaledValue == null) return null;

StringBuilder builder = new StringBuilder();
if (unscaledValue < 0) builder.append('-');
String absValue = String.valueOf(Math.abs(unscaledValue));

if (scale >= 0) {
builder.append(absValue);
for (int i = 0; i < scale; i++) builder.append('0');
return builder.toString();
}

if (absValue.length() <= -scale) {
builder.append('0').append('.');
for (int i = 0; i < -scale - absValue.length(); i++) builder.append('0');
builder.append(absValue);
} else {
int split = absValue.length() + scale;
builder.append(absValue.substring(0, split))
.append('.')
.append(absValue.substring(split));
}
return builder.toString();
}

private static <T> String[] getStringArray(T[] input) {
String[] result = new String[input.length];
for (int i = 0 ; i < input.length ; i++) {
Expand Down