Skip to content

Commit

Permalink
Java: Support struct scalar (#8327)
Browse files Browse the repository at this point in the history
Current PR is to support struct scalar in Java package, which is required by spark-rapids ([issue link](NVIDIA/spark-rapids#2436)). In detail, current PR introduces three new features:
1. create struct scalar through Java API
2. get children of struct scalar through Java API
3. create struct column from (struct) scalar through Java API

Authors:
  - Alfred Xu (https://github.com/sperlingxx)

Approvers:
  - Robert (Bobby) Evans (https://github.com/revans2)

URL: #8327
  • Loading branch information
sperlingxx authored May 26, 2021
1 parent eea8cab commit 2383193
Show file tree
Hide file tree
Showing 5 changed files with 371 additions and 4 deletions.
127 changes: 126 additions & 1 deletion java/src/main/java/ai/rapids/cudf/Scalar.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
import org.slf4j.LoggerFactory;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

/**
Expand Down Expand Up @@ -373,6 +373,97 @@ public static Scalar listFromColumnView(ColumnView list) {
return new Scalar(DType.LIST, makeListScalar(list.getNativeView(), true));
}

/**
* Creates a null scalar of struct type.
*
* @param elementTypes data types of children in the struct
* @return a null scalar of struct type
*/
public static Scalar structFromNull(HostColumnVector.DataType... elementTypes) {
ColumnVector[] children = new ColumnVector[elementTypes.length];
long[] childHandles = new long[elementTypes.length];
RuntimeException error = null;
try {
for (int i = 0; i < elementTypes.length; i++) {
// Build column vector having single null value rather than empty column vector,
// because struct scalar requires row count of children columns == 1.
children[i] = buildNullColumnVector(elementTypes[i]);
childHandles[i] = children[i].getNativeView();
}
return new Scalar(DType.STRUCT, makeStructScalar(childHandles, false));
} catch (RuntimeException ex) {
error = ex;
throw ex;
} catch (Exception ex) {
error = new RuntimeException(ex);
throw ex;
} finally {
// close all empty children
for (ColumnVector child : children) {
// We closed all created ColumnViews when we hit null. Therefore we exit the loop.
if (child == null) break;
// suppress exception during the close process to ensure that all elements are closed
try {
child.close();
} catch (Exception ex) {
if (error == null) {
error = new RuntimeException(ex);
continue;
}
error.addSuppressed(ex);
}
}
if (error != null) throw error;
}
}

/**
* Creates a scalar of struct from a ColumnView.
*
* @param columns children columns of struct
* @return a Struct scalar
*/
public static Scalar structFromColumnViews(ColumnView... columns) {
if (columns == null) {
throw new IllegalArgumentException("input columns should NOT be null");
}
long[] columnHandles = new long[columns.length];
for (int i = 0; i < columns.length; i++) {
columnHandles[i] = columns[i].getNativeView();
}
return new Scalar(DType.STRUCT, makeStructScalar(columnHandles, true));
}

/**
* Build column vector of single row who holds a null value
*
* @param hostType host data type of null column vector
* @return the null vector
*/
private static ColumnVector buildNullColumnVector(HostColumnVector.DataType hostType) {
DType dt = hostType.getType();
if (!dt.isNestedType()) {
try (HostColumnVector.Builder builder = HostColumnVector.builder(dt, 1)) {
builder.appendNull();
try (HostColumnVector hcv = builder.build()) {
return hcv.copyToDevice();
}
}
} else if (dt.typeId == DType.DTypeEnum.LIST) {
// type of List doesn't matter here because of type erasure in Java
try (HostColumnVector hcv = HostColumnVector.fromLists(hostType, (List<Integer>) null)) {
return hcv.copyToDevice();
}
} else if (dt.typeId == DType.DTypeEnum.STRUCT) {
try (HostColumnVector hcv = HostColumnVector.fromStructs(
hostType, (HostColumnVector.StructData) null)) {
return hcv.copyToDevice();
}
} else {
throw new IllegalArgumentException("Unsupported data type: " + hostType);
}
}

private static native void closeScalar(long scalarHandle);
private static native boolean isScalarValid(long scalarHandle);
private static native byte getByte(long scalarHandle);
Expand All @@ -383,6 +474,7 @@ public static Scalar listFromColumnView(ColumnView list) {
private static native double getDouble(long scalarHandle);
private static native byte[] getUTF8(long scalarHandle);
private static native long getListAsColumnView(long scalarHandle);
private static native long[] getChildrenFromStructScalar(long scalarHandle);
private static native long makeBool8Scalar(boolean isValid, boolean value);
private static native long makeInt8Scalar(byte value, boolean isValid);
private static native long makeUint8Scalar(byte value, boolean isValid);
Expand All @@ -402,6 +494,7 @@ public static Scalar listFromColumnView(ColumnView list) {
private static native long makeDecimal32Scalar(int value, int scale, boolean isValid);
private static native long makeDecimal64Scalar(long value, int scale, boolean isValid);
private static native long makeListScalar(long viewHandle, boolean isValid);
private static native long makeStructScalar(long[] viewHandles, boolean isValid);


Scalar(DType type, long scalarHandle) {
Expand Down Expand Up @@ -539,6 +632,38 @@ public ColumnView getListAsColumnView() {
return new ColumnView(getListAsColumnView(getScalarHandle()));
}

/**
* Fetches views of children columns from struct scalar.
* The returned ColumnViews should be closed appropriately. Otherwise, a native memory leak will occur.
*
* @return array of column views refer to children of struct scalar
*/
public ColumnView[] getChildrenFromStructScalar() {
assert DType.STRUCT.equals(type) : "Cannot get table for the vector of type " + type;

long[] childHandles = getChildrenFromStructScalar(getScalarHandle());
ColumnView[] children = new ColumnView[childHandles.length];
try {
for (int i = 0; i < children.length; i++) {
children[i] = new ColumnView(childHandles[i]);
}
} catch (Exception ex) {
// close all created ColumnViews if exception thrown
for (ColumnView child : children) {
// We closed all created ColumnViews when we hit null. Therefore we exit the loop.
if (child == null) break;
// make sure the close process is exception-free
try {
child.close();
} catch (Exception suppressed) {
ex.addSuppressed(suppressed);
}
}
throw ex;
}
return children;
}

@Override
public ColumnVector binaryOp(BinaryOp op, BinaryOperable rhs, DType outType) {
if (rhs instanceof ColumnView) {
Expand Down
6 changes: 6 additions & 0 deletions java/src/main/native/src/ColumnVectorJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,12 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromScalar(JNIEnv *env,
std::move(mask_buffer));

col = cudf::fill(str_col->view(), 0, row_count, *scalar_val);
} else if (scalar_val->type().id() == cudf::type_id::STRUCT && row_count == 0) {
// Specialize the creation of empty struct column, since libcudf doesn't support it.
auto struct_scalar = reinterpret_cast<cudf::struct_scalar const *>(j_scalar);
auto children = cudf::empty_like(struct_scalar->view())->release();
auto mask_buffer = cudf::create_null_mask(0, cudf::mask_state::UNALLOCATED);
col = cudf::make_structs_column(0, std::move(children), 0, std::move(mask_buffer));
} else {
col = cudf::make_column_from_scalar(*scalar_val, row_count);
}
Expand Down
37 changes: 37 additions & 0 deletions java/src/main/native/src/ScalarJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,22 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_getListAsColumnView(JNIEnv *e
CATCH_STD(env, 0);
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Scalar_getChildrenFromStructScalar(JNIEnv *env, jclass,
jlong scalar_handle) {
JNI_NULL_CHECK(env, scalar_handle, "scalar handle is null", 0);
try {
cudf::jni::auto_set_device(env);
const auto s = reinterpret_cast<cudf::struct_scalar*>(scalar_handle);
const cudf::table_view& table = s->view();
cudf::jni::native_jpointerArray<cudf::column_view> column_handles(env, table.num_columns());
for (int i = 0; i < table.num_columns(); i++) {
column_handles[i] = new cudf::column_view(table.column(i));
}
return column_handles.get_jArray();
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeBool8Scalar(JNIEnv *env, jclass,
jboolean value,
jboolean is_valid) {
Expand Down Expand Up @@ -477,4 +493,25 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeListScalar(JNIEnv *env, j
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeStructScalar(JNIEnv *env, jclass,
jlongArray handles,
jboolean is_valid) {
JNI_NULL_CHECK(env, handles, "native view handles are null", 0)
try {
cudf::jni::auto_set_device(env);
std::unique_ptr<cudf::column_view> ret;
cudf::jni::native_jpointerArray<cudf::column_view> column_pointers(env, handles);
std::vector<cudf::column_view> columns;
columns.reserve(column_pointers.size());
std::transform(column_pointers.data(),
column_pointers.data() + column_pointers.size(),
std::back_inserter(columns),
[](auto const& col_ptr) { return *col_ptr; });
auto s = std::make_unique<cudf::struct_scalar>(
cudf::host_span<cudf::column_view const>{columns}, is_valid);
return reinterpret_cast<jlong>(s.release());
}
CATCH_STD(env, 0);
}

} // extern "C"
83 changes: 81 additions & 2 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1188,8 +1188,14 @@ void testFromScalarZeroRows() {
s = Scalar.durationFromLong(DType.create(type), 21313);
break;
case EMPTY:
case STRUCT:
continue;
case STRUCT:
try (ColumnVector col1 = ColumnVector.fromInts(1);
ColumnVector col2 = ColumnVector.fromStrings("A");
ColumnVector col3 = ColumnVector.fromDoubles(1.23)) {
s = Scalar.structFromColumnViews(col1, col2, col3);
}
break;
case LIST:
try (ColumnVector list = ColumnVector.fromInts(1, 2, 3)) {
s = Scalar.listFromColumnView(list);
Expand Down Expand Up @@ -1367,8 +1373,24 @@ void testFromScalar() {
break;
}
case EMPTY:
case STRUCT:
continue;
case STRUCT:
try (ColumnVector col0 = ColumnVector.fromInts(1);
ColumnVector col1 = ColumnVector.fromBoxedDoubles((Double) null);
ColumnVector col2 = ColumnVector.fromStrings("a");
ColumnVector col3 = ColumnVector.fromDecimals(BigDecimal.TEN);
ColumnVector col4 = ColumnVector.daysFromInts(10)) {
s = Scalar.structFromColumnViews(col0, col1, col2, col3, col4);
StructData structData = new StructData(1, null, "a", BigDecimal.TEN, 10);
expected = ColumnVector.fromStructs(new HostColumnVector.StructType(true,
new HostColumnVector.BasicType(true, DType.INT32),
new HostColumnVector.BasicType(true, DType.FLOAT64),
new HostColumnVector.BasicType(true, DType.STRING),
new HostColumnVector.BasicType(true, DType.create(DType.DTypeEnum.DECIMAL32, 0)),
new HostColumnVector.BasicType(true, DType.TIMESTAMP_DAYS)),
structData, structData, structData, structData);
}
break;
case LIST:
try (ColumnVector list = ColumnVector.fromInts(1, 2, 3)) {
s = Scalar.listFromColumnView(list);
Expand Down Expand Up @@ -1535,6 +1557,63 @@ void testFromScalarListOfStruct() {
}
}

@Test
void testFromScalarNullStruct() {
final int rowCount = 4;
for (DType.DTypeEnum typeEnum : DType.DTypeEnum.values()) {
DType dType = typeEnum.isDecimalType() ? DType.create(typeEnum, -8) : DType.create(typeEnum);
DataType hDataType;
if (DType.EMPTY.equals(dType)) {
continue;
} else if (DType.LIST.equals(dType)) {
// list of list of int32
hDataType = new ListType(true, new BasicType(true, DType.INT32));
} else if (DType.STRUCT.equals(dType)) {
// list of struct of int32
hDataType = new StructType(true, new BasicType(true, DType.INT32));
} else {
// list of non nested type
hDataType = new BasicType(true, dType);
}
try (Scalar s = Scalar.structFromNull(hDataType, hDataType, hDataType);
ColumnVector c = ColumnVector.fromScalar(s, rowCount);
HostColumnVector hc = c.copyToHost()) {
assertEquals(DType.STRUCT, c.getType());
assertEquals(rowCount, c.getRowCount());
assertEquals(rowCount, c.getNullCount());
for (int i = 0; i < rowCount; ++i) {
assertTrue(hc.isNull(i));
}
assertEquals(3, c.getNumChildren());
ColumnView[] children = new ColumnView[]{c.getChildColumnView(0),
c.getChildColumnView(1), c.getChildColumnView(2)};
try {
for (ColumnView child : children) {
assertEquals(dType, child.getType());
assertEquals(rowCount, child.getRowCount());
assertEquals(rowCount, child.getNullCount());
if (child.getType() == DType.LIST) {
try (ColumnView childOfChild = child.getChildColumnView(0)) {
assertEquals(DType.INT32, childOfChild.getType());
assertEquals(0L, childOfChild.getRowCount());
assertEquals(0L, childOfChild.getNullCount());
}
} else if (child.getType() == DType.STRUCT) {
assertEquals(1, child.getNumChildren());
try (ColumnView childOfChild = child.getChildColumnView(0)) {
assertEquals(DType.INT32, childOfChild.getType());
assertEquals(rowCount, childOfChild.getRowCount());
assertEquals(rowCount, childOfChild.getNullCount());
}
}
}
} finally {
for (ColumnView cv : children) cv.close();
}
}
}
}

@Test
void testReplaceNullsScalarEmptyColumn() {
try (ColumnVector input = ColumnVector.fromBoxedBooleans();
Expand Down
Loading

0 comments on commit 2383193

Please sign in to comment.