Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Java: Support struct scalar [skip ci] #8327

Merged
merged 8 commits into from
May 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 126 additions & 1 deletion java/src/main/java/ai/rapids/cudf/Scalar.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
import org.slf4j.LoggerFactory;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

/**
Expand Down Expand Up @@ -373,6 +373,97 @@ public static Scalar listFromColumnView(ColumnView list) {
return new Scalar(DType.LIST, makeListScalar(list.getNativeView(), true));
}

/**
* Creates a null scalar of struct type.
*
* @param elementTypes data types of children in the struct
* @return a null scalar of struct type
*/
public static Scalar structFromNull(HostColumnVector.DataType... elementTypes) {
ColumnVector[] children = new ColumnVector[elementTypes.length];
firestarman marked this conversation as resolved.
Show resolved Hide resolved
long[] childHandles = new long[elementTypes.length];
RuntimeException error = null;
try {
for (int i = 0; i < elementTypes.length; i++) {
// Build column vector having single null value rather than empty column vector,
// because struct scalar requires row count of children columns == 1.
children[i] = buildNullColumnVector(elementTypes[i]);
childHandles[i] = children[i].getNativeView();
}
return new Scalar(DType.STRUCT, makeStructScalar(childHandles, false));
} catch (RuntimeException ex) {
error = ex;
throw ex;
} catch (Exception ex) {
error = new RuntimeException(ex);
throw ex;
} finally {
// close all empty children
for (ColumnVector child : children) {
// We closed all created ColumnViews when we hit null. Therefore we exit the loop.
if (child == null) break;
// suppress exception during the close process to ensure that all elements are closed
try {
child.close();
} catch (Exception ex) {
if (error == null) {
error = new RuntimeException(ex);
continue;
}
error.addSuppressed(ex);
}
}
if (error != null) throw error;
}
}

/**
 * Creates a valid (non-null) struct scalar whose children are the given column views.
 *
 * @param columns children columns of struct
 * @return a Struct scalar
 * @throws IllegalArgumentException if {@code columns} is null
 */
public static Scalar structFromColumnViews(ColumnView... columns) {
  if (columns == null) {
    throw new IllegalArgumentException("input columns should NOT be null");
  }
  // Collect the native view handle of every child, preserving the argument order.
  final long[] nativeViews = new long[columns.length];
  int position = 0;
  for (ColumnView column : columns) {
    nativeViews[position++] = column.getNativeView();
  }
  final long structHandle = makeStructScalar(nativeViews, true);
  return new Scalar(DType.STRUCT, structHandle);
}

/**
* Build column vector of single row who holds a null value
*
* @param hostType host data type of null column vector
* @return the null vector
*/
private static ColumnVector buildNullColumnVector(HostColumnVector.DataType hostType) {
DType dt = hostType.getType();
if (!dt.isNestedType()) {
try (HostColumnVector.Builder builder = HostColumnVector.builder(dt, 1)) {
builder.appendNull();
try (HostColumnVector hcv = builder.build()) {
return hcv.copyToDevice();
}
}
} else if (dt.typeId == DType.DTypeEnum.LIST) {
// type of List doesn't matter here because of type erasure in Java
try (HostColumnVector hcv = HostColumnVector.fromLists(hostType, (List<Integer>) null)) {
sperlingxx marked this conversation as resolved.
Show resolved Hide resolved
return hcv.copyToDevice();
}
} else if (dt.typeId == DType.DTypeEnum.STRUCT) {
try (HostColumnVector hcv = HostColumnVector.fromStructs(
hostType, (HostColumnVector.StructData) null)) {
return hcv.copyToDevice();
}
} else {
throw new IllegalArgumentException("Unsupported data type: " + hostType);
}
}

// Native (JNI) entry points that operate on an existing native scalar, identified by its handle.
private static native void closeScalar(long scalarHandle);
private static native boolean isScalarValid(long scalarHandle);
private static native byte getByte(long scalarHandle);
Expand All @@ -383,6 +474,7 @@ public static Scalar listFromColumnView(ColumnView list) {
private static native double getDouble(long scalarHandle);
private static native byte[] getUTF8(long scalarHandle);
private static native long getListAsColumnView(long scalarHandle);
// Returns one native column view handle per child of the struct scalar. The native side
// allocates a new column_view for every child, so each handle has to be released by the caller.
private static native long[] getChildrenFromStructScalar(long scalarHandle);
// NOTE(review): the parameter names here are (isValid, value), while the other factories and
// the JNI signature of makeBool8Scalar use (value, is_valid) — confirm the names match the
// order that callers actually pass.
private static native long makeBool8Scalar(boolean isValid, boolean value);
private static native long makeInt8Scalar(byte value, boolean isValid);
private static native long makeUint8Scalar(byte value, boolean isValid);
Expand All @@ -402,6 +494,7 @@ public static Scalar listFromColumnView(ColumnView list) {
private static native long makeDecimal32Scalar(int value, int scale, boolean isValid);
private static native long makeDecimal64Scalar(long value, int scale, boolean isValid);
private static native long makeListScalar(long viewHandle, boolean isValid);
// Builds a native struct scalar from an array of native column view handles (one per child)
// and returns the handle of the new scalar.
private static native long makeStructScalar(long[] viewHandles, boolean isValid);


Scalar(DType type, long scalarHandle) {
Expand Down Expand Up @@ -539,6 +632,38 @@ public ColumnView getListAsColumnView() {
return new ColumnView(getListAsColumnView(getScalarHandle()));
}

/**
 * Fetches views of children columns from struct scalar.
 * The returned ColumnViews should be closed appropriately. Otherwise, a native memory leak
 * will occur.
 *
 * @return array of column views that refer to the children of this struct scalar
 */
public ColumnView[] getChildrenFromStructScalar() {
  assert DType.STRUCT.equals(type) : "Cannot get children for the scalar of type " + type;

  long[] childHandles = getChildrenFromStructScalar(getScalarHandle());
  ColumnView[] children = new ColumnView[childHandles.length];
  try {
    for (int i = 0; i < children.length; i++) {
      children[i] = new ColumnView(childHandles[i]);
    }
  } catch (Throwable failure) {
    // Catch Throwable rather than Exception so that the views created so far are also
    // released when an Error (for example a failed assertion) interrupts the loop.
    // NOTE(review): handles at or after the failing index were never wrapped in a ColumnView
    // and are not released here — confirm whether they need native cleanup.
    for (ColumnView child : children) {
      // Views are created in order, so the first null marks the end of the created ones.
      if (child == null) break;
      // make sure the close process does not hide the original failure
      try {
        child.close();
      } catch (Exception suppressed) {
        failure.addSuppressed(suppressed);
      }
    }
    throw failure;
  }
  return children;
}

@Override
public ColumnVector binaryOp(BinaryOp op, BinaryOperable rhs, DType outType) {
if (rhs instanceof ColumnView) {
Expand Down
6 changes: 6 additions & 0 deletions java/src/main/native/src/ColumnVectorJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,12 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromScalar(JNIEnv *env,
std::move(mask_buffer));

col = cudf::fill(str_col->view(), 0, row_count, *scalar_val);
} else if (scalar_val->type().id() == cudf::type_id::STRUCT && row_count == 0) {
// Specialize the creation of empty struct column, since libcudf doesn't support it.
auto struct_scalar = reinterpret_cast<cudf::struct_scalar const *>(j_scalar);
auto children = cudf::empty_like(struct_scalar->view())->release();
auto mask_buffer = cudf::create_null_mask(0, cudf::mask_state::UNALLOCATED);
col = cudf::make_structs_column(0, std::move(children), 0, std::move(mask_buffer));
} else {
col = cudf::make_column_from_scalar(*scalar_val, row_count);
}
Expand Down
37 changes: 37 additions & 0 deletions java/src/main/native/src/ScalarJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,22 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_getListAsColumnView(JNIEnv *e
CATCH_STD(env, 0);
}

// Returns, as a Java long array, one newly allocated cudf::column_view pointer per child
// column of the struct scalar referenced by scalar_handle.
JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Scalar_getChildrenFromStructScalar(JNIEnv *env, jclass,
                                                                                    jlong scalar_handle) {
  JNI_NULL_CHECK(env, scalar_handle, "scalar handle is null", 0);
  try {
    cudf::jni::auto_set_device(env);
    auto const struct_ptr = reinterpret_cast<cudf::struct_scalar *>(scalar_handle);
    cudf::table_view const children = struct_ptr->view();
    auto const num_children = children.num_columns();
    cudf::jni::native_jpointerArray<cudf::column_view> child_handles(env, num_children);
    // Each child gets its own heap-allocated view so it can be handed over as a raw handle.
    for (int idx = 0; idx < num_children; ++idx) {
      child_handles[idx] = new cudf::column_view(children.column(idx));
    }
    return child_handles.get_jArray();
  }
  CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeBool8Scalar(JNIEnv *env, jclass,
jboolean value,
jboolean is_valid) {
Expand Down Expand Up @@ -477,4 +493,25 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeListScalar(JNIEnv *env, j
CATCH_STD(env, 0);
}

// Creates a cudf::struct_scalar from an array of native cudf::column_view handles (one per
// child column) and returns the pointer to the new scalar as a jlong handle.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeStructScalar(JNIEnv *env, jclass,
                                                                    jlongArray handles,
                                                                    jboolean is_valid) {
  JNI_NULL_CHECK(env, handles, "native view handles are null", 0);
  try {
    cudf::jni::auto_set_device(env);
    cudf::jni::native_jpointerArray<cudf::column_view> column_pointers(env, handles);
    // Dereference every handle into a contiguous vector of views, which is what the
    // host_span passed to the struct_scalar constructor needs.
    std::vector<cudf::column_view> columns;
    columns.reserve(column_pointers.size());
    std::transform(column_pointers.data(),
                   column_pointers.data() + column_pointers.size(),
                   std::back_inserter(columns),
                   [](auto const& col_ptr) { return *col_ptr; });
    auto s = std::make_unique<cudf::struct_scalar>(
        cudf::host_span<cudf::column_view const>{columns}, is_valid);
    // Ownership of the scalar moves to the caller through the returned handle.
    return reinterpret_cast<jlong>(s.release());
  }
  CATCH_STD(env, 0);
}

} // extern "C"
83 changes: 81 additions & 2 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1188,8 +1188,14 @@ void testFromScalarZeroRows() {
s = Scalar.durationFromLong(DType.create(type), 21313);
break;
case EMPTY:
case STRUCT:
continue;
case STRUCT:
try (ColumnVector col1 = ColumnVector.fromInts(1);
ColumnVector col2 = ColumnVector.fromStrings("A");
ColumnVector col3 = ColumnVector.fromDoubles(1.23)) {
s = Scalar.structFromColumnViews(col1, col2, col3);
}
break;
case LIST:
try (ColumnVector list = ColumnVector.fromInts(1, 2, 3)) {
s = Scalar.listFromColumnView(list);
Expand Down Expand Up @@ -1367,8 +1373,24 @@ void testFromScalar() {
break;
}
case EMPTY:
case STRUCT:
continue;
case STRUCT:
try (ColumnVector col0 = ColumnVector.fromInts(1);
ColumnVector col1 = ColumnVector.fromBoxedDoubles((Double) null);
ColumnVector col2 = ColumnVector.fromStrings("a");
ColumnVector col3 = ColumnVector.fromDecimals(BigDecimal.TEN);
ColumnVector col4 = ColumnVector.daysFromInts(10)) {
s = Scalar.structFromColumnViews(col0, col1, col2, col3, col4);
StructData structData = new StructData(1, null, "a", BigDecimal.TEN, 10);
expected = ColumnVector.fromStructs(new HostColumnVector.StructType(true,
new HostColumnVector.BasicType(true, DType.INT32),
new HostColumnVector.BasicType(true, DType.FLOAT64),
new HostColumnVector.BasicType(true, DType.STRING),
new HostColumnVector.BasicType(true, DType.create(DType.DTypeEnum.DECIMAL32, 0)),
new HostColumnVector.BasicType(true, DType.TIMESTAMP_DAYS)),
structData, structData, structData, structData);
}
break;
case LIST:
try (ColumnVector list = ColumnVector.fromInts(1, 2, 3)) {
s = Scalar.listFromColumnView(list);
Expand Down Expand Up @@ -1535,6 +1557,63 @@ void testFromScalarListOfStruct() {
}
}

/**
 * For every supported element type, builds a null struct scalar with three children of that
 * type, expands it to a column, and checks that the rows and all children are null.
 */
@Test
void testFromScalarNullStruct() {
  final int rowCount = 4;
  for (DType.DTypeEnum typeEnum : DType.DTypeEnum.values()) {
    DType dType = typeEnum.isDecimalType() ? DType.create(typeEnum, -8) : DType.create(typeEnum);
    DataType hDataType;
    if (DType.EMPTY.equals(dType)) {
      continue;
    } else if (DType.LIST.equals(dType)) {
      // struct of list of int32
      hDataType = new ListType(true, new BasicType(true, DType.INT32));
    } else if (DType.STRUCT.equals(dType)) {
      // struct of struct of int32
      hDataType = new StructType(true, new BasicType(true, DType.INT32));
    } else {
      // struct of non nested type
      hDataType = new BasicType(true, dType);
    }
    try (Scalar s = Scalar.structFromNull(hDataType, hDataType, hDataType);
         ColumnVector c = ColumnVector.fromScalar(s, rowCount);
         HostColumnVector hc = c.copyToHost()) {
      assertEquals(DType.STRUCT, c.getType());
      assertEquals(rowCount, c.getRowCount());
      assertEquals(rowCount, c.getNullCount());
      for (int i = 0; i < rowCount; ++i) {
        assertTrue(hc.isNull(i));
      }
      assertEquals(3, c.getNumChildren());
      ColumnView[] children = new ColumnView[3];
      try {
        // Fetch the child views inside the try block so that the views already obtained are
        // still closed when a later fetch fails.
        for (int i = 0; i < children.length; i++) {
          children[i] = c.getChildColumnView(i);
        }
        for (ColumnView child : children) {
          assertEquals(dType, child.getType());
          assertEquals(rowCount, child.getRowCount());
          assertEquals(rowCount, child.getNullCount());
          if (DType.LIST.equals(child.getType())) {
            try (ColumnView childOfChild = child.getChildColumnView(0)) {
              assertEquals(DType.INT32, childOfChild.getType());
              assertEquals(0L, childOfChild.getRowCount());
              assertEquals(0L, childOfChild.getNullCount());
            }
          } else if (DType.STRUCT.equals(child.getType())) {
            assertEquals(1, child.getNumChildren());
            try (ColumnView childOfChild = child.getChildColumnView(0)) {
              assertEquals(DType.INT32, childOfChild.getType());
              assertEquals(rowCount, childOfChild.getRowCount());
              assertEquals(rowCount, childOfChild.getNullCount());
            }
          }
        }
      } finally {
        for (ColumnView cv : children) {
          if (cv != null) cv.close();
        }
      }
    }
  }
}

@Test
void testReplaceNullsScalarEmptyColumn() {
try (ColumnVector input = ColumnVector.fromBoxedBooleans();
Expand Down
Loading