From 238319368e1207abfaec0fa890503eca4f31e777 Mon Sep 17 00:00:00 2001 From: Alfred Xu Date: Wed, 26 May 2021 19:23:11 +0800 Subject: [PATCH] Java: Support struct scalar (#8327) Current PR is to support struct scalar in Java package, which is required by spark-rapids ([issue link](https://github.com/NVIDIA/spark-rapids/issues/2436)). In detail, current PR introduces three new features: 1. create struct scalar through Java API 2. get children of struct scalar through Java API 3. create struct column from (struct) scalar through Java API Authors: - Alfred Xu (https://github.com/sperlingxx) Approvers: - Robert (Bobby) Evans (https://github.com/revans2) URL: https://github.com/rapidsai/cudf/pull/8327 --- java/src/main/java/ai/rapids/cudf/Scalar.java | 127 +++++++++++++++++- java/src/main/native/src/ColumnVectorJni.cpp | 6 + java/src/main/native/src/ScalarJni.cpp | 37 +++++ .../java/ai/rapids/cudf/ColumnVectorTest.java | 83 +++++++++++- .../test/java/ai/rapids/cudf/ScalarTest.java | 122 ++++++++++++++++- 5 files changed, 371 insertions(+), 4 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/Scalar.java b/java/src/main/java/ai/rapids/cudf/Scalar.java index 7794b57c3f9..925cc89a51a 100644 --- a/java/src/main/java/ai/rapids/cudf/Scalar.java +++ b/java/src/main/java/ai/rapids/cudf/Scalar.java @@ -22,9 +22,9 @@ import org.slf4j.LoggerFactory; import java.math.BigDecimal; -import java.math.BigInteger; import java.nio.charset.StandardCharsets; import java.util.Arrays; +import java.util.List; import java.util.Objects; /** @@ -373,6 +373,97 @@ public static Scalar listFromColumnView(ColumnView list) { return new Scalar(DType.LIST, makeListScalar(list.getNativeView(), true)); } + /** + * Creates a null scalar of struct type. + * + * @param elementTypes data types of children in the struct + * @return a null scalar of struct type + */ + public static Scalar structFromNull(HostColumnVector.DataType... elementTypes) { + ColumnVector[] children = new ColumnVector[elementTypes.length]; + long[] childHandles = new long[elementTypes.length]; + RuntimeException error = null; + try { + for (int i = 0; i < elementTypes.length; i++) { + // Build column vector having single null value rather than empty column vector, + // because struct scalar requires row count of children columns == 1. + children[i] = buildNullColumnVector(elementTypes[i]); + childHandles[i] = children[i].getNativeView(); + } + return new Scalar(DType.STRUCT, makeStructScalar(childHandles, false)); + } catch (RuntimeException ex) { + error = ex; + throw ex; + } catch (Exception ex) { + error = new RuntimeException(ex); + throw ex; + } finally { + // close all empty children + for (ColumnVector child : children) { + // We closed all created ColumnViews when we hit null. Therefore we exit the loop. + if (child == null) break; + // suppress exception during the close process to ensure that all elements are closed + try { + child.close(); + } catch (Exception ex) { + if (error == null) { + error = new RuntimeException(ex); + continue; + } + error.addSuppressed(ex); + } + } + if (error != null) throw error; + } + } + + /** + * Creates a scalar of struct from a ColumnView. + * + * @param columns children columns of struct + * @return a Struct scalar + */ + public static Scalar structFromColumnViews(ColumnView... columns) { + if (columns == null) { + throw new IllegalArgumentException("input columns should NOT be null"); + } + long[] columnHandles = new long[columns.length]; + for (int i = 0; i < columns.length; i++) { + columnHandles[i] = columns[i].getNativeView(); + } + return new Scalar(DType.STRUCT, makeStructScalar(columnHandles, true)); + } + + /** + * Build column vector of single row who holds a null value + * + * @param hostType host data type of null column vector + * @return the null vector + */ + private static ColumnVector buildNullColumnVector(HostColumnVector.DataType hostType) { + DType dt = hostType.getType(); + if (!dt.isNestedType()) { + try (HostColumnVector.Builder builder = HostColumnVector.builder(dt, 1)) { + builder.appendNull(); + try (HostColumnVector hcv = builder.build()) { + return hcv.copyToDevice(); + } + } + } else if (dt.typeId == DType.DTypeEnum.LIST) { + // type of List doesn't matter here because of type erasure in Java + try (HostColumnVector hcv = HostColumnVector.fromLists(hostType, (List) null)) { + return hcv.copyToDevice(); + } + } else if (dt.typeId == DType.DTypeEnum.STRUCT) { + try (HostColumnVector hcv = HostColumnVector.fromStructs( + hostType, (HostColumnVector.StructData) null)) { + return hcv.copyToDevice(); + } + } else { + throw new IllegalArgumentException("Unsupported data type: " + hostType); + } + } + private static native void closeScalar(long scalarHandle); private static native boolean isScalarValid(long scalarHandle); private static native byte getByte(long scalarHandle); @@ -383,6 +474,7 @@ public static Scalar listFromColumnView(ColumnView list) { private static native double getDouble(long scalarHandle); private static native byte[] getUTF8(long scalarHandle); private static native long getListAsColumnView(long scalarHandle); + private static native long[] getChildrenFromStructScalar(long scalarHandle); private static native long makeBool8Scalar(boolean isValid, boolean value); private static native long makeInt8Scalar(byte value, boolean isValid); private static native long makeUint8Scalar(byte value, boolean isValid); @@ -402,6 +494,7 @@ public static Scalar listFromColumnView(ColumnView list) { private static native long makeDecimal32Scalar(int value, int scale, boolean isValid); private static native long makeDecimal64Scalar(long value, int scale, boolean isValid); private static native long makeListScalar(long viewHandle, boolean isValid); + private static native long makeStructScalar(long[] viewHandles, boolean isValid); Scalar(DType type, long scalarHandle) { @@ -539,6 +632,38 @@ public ColumnView getListAsColumnView() { return new ColumnView(getListAsColumnView(getScalarHandle())); } + /** + * Fetches views of children columns from struct scalar. + * The returned ColumnViews should be closed appropriately. Otherwise, a native memory leak will occur. + * + * @return array of column views refer to children of struct scalar + */ + public ColumnView[] getChildrenFromStructScalar() { + assert DType.STRUCT.equals(type) : "Cannot get table for the vector of type " + type; + + long[] childHandles = getChildrenFromStructScalar(getScalarHandle()); + ColumnView[] children = new ColumnView[childHandles.length]; + try { + for (int i = 0; i < children.length; i++) { + children[i] = new ColumnView(childHandles[i]); + } + } catch (Exception ex) { + // close all created ColumnViews if exception thrown + for (ColumnView child : children) { + // We closed all created ColumnViews when we hit null. Therefore we exit the loop. + if (child == null) break; + // make sure the close process is exception-free + try { + child.close(); + } catch (Exception suppressed) { + ex.addSuppressed(suppressed); + } + } + throw ex; + } + return children; + } + @Override public ColumnVector binaryOp(BinaryOp op, BinaryOperable rhs, DType outType) { if (rhs instanceof ColumnView) { diff --git a/java/src/main/native/src/ColumnVectorJni.cpp b/java/src/main/native/src/ColumnVectorJni.cpp index 2953a6221e8..97ea9c51512 100644 --- a/java/src/main/native/src/ColumnVectorJni.cpp +++ b/java/src/main/native/src/ColumnVectorJni.cpp @@ -237,6 +237,12 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromScalar(JNIEnv *env, std::move(mask_buffer)); col = cudf::fill(str_col->view(), 0, row_count, *scalar_val); + } else if (scalar_val->type().id() == cudf::type_id::STRUCT && row_count == 0) { + // Specialize the creation of empty struct column, since libcudf doesn't support it. + auto struct_scalar = reinterpret_cast(j_scalar); + auto children = cudf::empty_like(struct_scalar->view())->release(); + auto mask_buffer = cudf::create_null_mask(0, cudf::mask_state::UNALLOCATED); + col = cudf::make_structs_column(0, std::move(children), 0, std::move(mask_buffer)); } else { col = cudf::make_column_from_scalar(*scalar_val, row_count); } diff --git a/java/src/main/native/src/ScalarJni.cpp b/java/src/main/native/src/ScalarJni.cpp index 95f934ff91b..8939c77f234 100644 --- a/java/src/main/native/src/ScalarJni.cpp +++ b/java/src/main/native/src/ScalarJni.cpp @@ -138,6 +138,22 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_getListAsColumnView(JNIEnv *e CATCH_STD(env, 0); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Scalar_getChildrenFromStructScalar(JNIEnv *env, jclass, + jlong scalar_handle) { + JNI_NULL_CHECK(env, scalar_handle, "scalar handle is null", 0); + try { + cudf::jni::auto_set_device(env); + const auto s = reinterpret_cast(scalar_handle); + const cudf::table_view& table = s->view(); + cudf::jni::native_jpointerArray column_handles(env, table.num_columns()); + for (int i = 0; i < table.num_columns(); i++) { + column_handles[i] = new cudf::column_view(table.column(i)); + } + return column_handles.get_jArray(); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeBool8Scalar(JNIEnv *env, jclass, jboolean value, jboolean is_valid) { @@ -477,4 +493,25 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeListScalar(JNIEnv *env, j CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeStructScalar(JNIEnv *env, jclass, + jlongArray handles, + jboolean is_valid) { + JNI_NULL_CHECK(env, handles, "native view handles are null", 0) + try { + cudf::jni::auto_set_device(env); + std::unique_ptr ret; + cudf::jni::native_jpointerArray column_pointers(env, handles); + std::vector columns; + columns.reserve(column_pointers.size()); + std::transform(column_pointers.data(), + column_pointers.data() + column_pointers.size(), + std::back_inserter(columns), + [](auto const& col_ptr) { return *col_ptr; }); + auto s = std::make_unique( + cudf::host_span{columns}, is_valid); + return reinterpret_cast(s.release()); + } + CATCH_STD(env, 0); +} + } // extern "C" diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 8da70afc6f3..e04462e138b 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -1188,8 +1188,14 @@ void testFromScalarZeroRows() { s = Scalar.durationFromLong(DType.create(type), 21313); break; case EMPTY: - case STRUCT: continue; + case STRUCT: + try (ColumnVector col1 = ColumnVector.fromInts(1); + ColumnVector col2 = ColumnVector.fromStrings("A"); + ColumnVector col3 = ColumnVector.fromDoubles(1.23)) { + s = Scalar.structFromColumnViews(col1, col2, col3); + } + break; case LIST: try (ColumnVector list = ColumnVector.fromInts(1, 2, 3)) { s = Scalar.listFromColumnView(list); @@ -1367,8 +1373,24 @@ void testFromScalar() { break; } case EMPTY: - case STRUCT: continue; + case STRUCT: + try (ColumnVector col0 = ColumnVector.fromInts(1); + ColumnVector col1 = ColumnVector.fromBoxedDoubles((Double) null); + ColumnVector col2 = ColumnVector.fromStrings("a"); + ColumnVector col3 = ColumnVector.fromDecimals(BigDecimal.TEN); + ColumnVector col4 = ColumnVector.daysFromInts(10)) { + s = Scalar.structFromColumnViews(col0, col1, col2, col3, col4); + StructData structData = new StructData(1, null, "a", BigDecimal.TEN, 10); + expected = ColumnVector.fromStructs(new HostColumnVector.StructType(true, + new HostColumnVector.BasicType(true, DType.INT32), + new HostColumnVector.BasicType(true, DType.FLOAT64), + new HostColumnVector.BasicType(true, DType.STRING), + new HostColumnVector.BasicType(true, DType.create(DType.DTypeEnum.DECIMAL32, 0)), + new HostColumnVector.BasicType(true, DType.TIMESTAMP_DAYS)), + structData, structData, structData, structData); + } + break; case LIST: try (ColumnVector list = ColumnVector.fromInts(1, 2, 3)) { s = Scalar.listFromColumnView(list); @@ -1535,6 +1557,63 @@ void testFromScalarListOfStruct() { } } + @Test + void testFromScalarNullStruct() { + final int rowCount = 4; + for (DType.DTypeEnum typeEnum : DType.DTypeEnum.values()) { + DType dType = typeEnum.isDecimalType() ? DType.create(typeEnum, -8) : DType.create(typeEnum); + DataType hDataType; + if (DType.EMPTY.equals(dType)) { + continue; + } else if (DType.LIST.equals(dType)) { + // list of list of int32 + hDataType = new ListType(true, new BasicType(true, DType.INT32)); + } else if (DType.STRUCT.equals(dType)) { + // list of struct of int32 + hDataType = new StructType(true, new BasicType(true, DType.INT32)); + } else { + // list of non nested type + hDataType = new BasicType(true, dType); + } + try (Scalar s = Scalar.structFromNull(hDataType, hDataType, hDataType); + ColumnVector c = ColumnVector.fromScalar(s, rowCount); + HostColumnVector hc = c.copyToHost()) { + assertEquals(DType.STRUCT, c.getType()); + assertEquals(rowCount, c.getRowCount()); + assertEquals(rowCount, c.getNullCount()); + for (int i = 0; i < rowCount; ++i) { + assertTrue(hc.isNull(i)); + } + assertEquals(3, c.getNumChildren()); + ColumnView[] children = new ColumnView[]{c.getChildColumnView(0), + c.getChildColumnView(1), c.getChildColumnView(2)}; + try { + for (ColumnView child : children) { + assertEquals(dType, child.getType()); + assertEquals(rowCount, child.getRowCount()); + assertEquals(rowCount, child.getNullCount()); + if (child.getType() == DType.LIST) { + try (ColumnView childOfChild = child.getChildColumnView(0)) { + assertEquals(DType.INT32, childOfChild.getType()); + assertEquals(0L, childOfChild.getRowCount()); + assertEquals(0L, childOfChild.getNullCount()); + } + } else if (child.getType() == DType.STRUCT) { + assertEquals(1, child.getNumChildren()); + try (ColumnView childOfChild = child.getChildColumnView(0)) { + assertEquals(DType.INT32, childOfChild.getType()); + assertEquals(rowCount, childOfChild.getRowCount()); + assertEquals(rowCount, childOfChild.getNullCount()); + } + } + } + } finally { + for (ColumnView cv : children) cv.close(); + } + } + } + } + @Test void testReplaceNullsScalarEmptyColumn() { try (ColumnVector input = ColumnVector.fromBoxedBooleans(); diff --git a/java/src/test/java/ai/rapids/cudf/ScalarTest.java b/java/src/test/java/ai/rapids/cudf/ScalarTest.java index a1078f2546b..00de3a696ad 100644 --- a/java/src/test/java/ai/rapids/cudf/ScalarTest.java +++ b/java/src/test/java/ai/rapids/cudf/ScalarTest.java @@ -70,7 +70,7 @@ public void testNull() { } } - // list scalar + // create elementType for nested types HostColumnVector.DataType hDataType; if (DType.EMPTY.equals(type)) { continue; @@ -84,6 +84,8 @@ public void testNull() { // list of non nested type hDataType = new BasicType(true, type); } + + // test list scalar with elementType(`type`) try (Scalar s = Scalar.listFromNull(hDataType); ColumnView listCv = s.getListAsColumnView()) { assertFalse(s.isValid(), "null validity for " + type); @@ -99,6 +101,23 @@ public void testNull() { } } } + + // test struct scalar with elementType(`type`) + try (Scalar s = Scalar.structFromNull(hDataType, hDataType, hDataType)) { + assertFalse(s.isValid(), "null validity for " + type); + assertEquals(DType.STRUCT, s.getType()); + + ColumnView[] children = s.getChildrenFromStructScalar(); + try { + for (ColumnView child : children) { + assertEquals(hDataType.getType(), child.getType()); + assertEquals(1L, child.getRowCount()); + assertEquals(1L, child.getNullCount()); + } + } finally { + for (ColumnView child : children) child.close(); + } + } } } @@ -287,4 +306,105 @@ public void testList() { } } } + + @Test + public void testStruct() { + try (ColumnVector col0 = ColumnVector.fromInts(1); + ColumnVector col1 = ColumnVector.fromBoxedDoubles(1.2); + ColumnVector col2 = ColumnVector.fromStrings("a"); + ColumnVector col3 = ColumnVector.fromDecimals(BigDecimal.TEN); + ColumnVector col4 = ColumnVector.daysFromInts(10); + ColumnVector col5 = ColumnVector.durationSecondsFromLongs(12345L); + Scalar s = Scalar.structFromColumnViews(col0, col1, col2, col3, col4, col5, col0, col1)) { + assertEquals(DType.STRUCT, s.getType()); + assertTrue(s.isValid()); + ColumnView[] children = s.getChildrenFromStructScalar(); + try { + assertColumnsAreEqual(col0, children[0]); + assertColumnsAreEqual(col1, children[1]); + assertColumnsAreEqual(col2, children[2]); + assertColumnsAreEqual(col3, children[3]); + assertColumnsAreEqual(col4, children[4]); + assertColumnsAreEqual(col5, children[5]); + assertColumnsAreEqual(col0, children[6]); + assertColumnsAreEqual(col1, children[7]); + } finally { + for (ColumnView child : children) child.close(); + } + } + + // test Struct Scalar with null members + try (ColumnVector col0 = ColumnVector.fromInts(1); + ColumnVector col1 = ColumnVector.fromBoxedDoubles((Double) null); + ColumnVector col2 = ColumnVector.fromStrings((String) null); + Scalar s1 = Scalar.structFromColumnViews(col0, col1, col2); + Scalar s2 = Scalar.structFromColumnViews(col1, col2)) { + ColumnView[] children = s1.getChildrenFromStructScalar(); + try { + assertColumnsAreEqual(col0, children[0]); + assertColumnsAreEqual(col1, children[1]); + assertColumnsAreEqual(col2, children[2]); + } finally { + for (ColumnView child : children) child.close(); + } + + ColumnView[] children2 = s2.getChildrenFromStructScalar(); + try { + assertColumnsAreEqual(col1, children2[0]); + assertColumnsAreEqual(col2, children2[1]); + } finally { + for (ColumnView child : children2) child.close(); + } + } + + // test Struct Scalar with single column + try (ColumnVector col0 = ColumnVector.fromInts(1234); + Scalar s = Scalar.structFromColumnViews(col0)) { + ColumnView[] children = s.getChildrenFromStructScalar(); + try { + assertColumnsAreEqual(col0, children[0]); + } finally { + children[0].close(); + } + } + + // test Struct Scalar without column + try (Scalar s = Scalar.structFromColumnViews()) { + assertEquals(DType.STRUCT, s.getType()); + assertTrue(s.isValid()); + ColumnView[] children = s.getChildrenFromStructScalar(); + assertEquals(0, children.length); + } + + // test Struct Scalar with nested types + HostColumnVector.DataType listType = new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)); + HostColumnVector.DataType structType = new HostColumnVector.StructType(true, + new HostColumnVector.BasicType(true, DType.INT32), + new HostColumnVector.BasicType(true, DType.INT64)); + HostColumnVector.DataType nestedStructType = new HostColumnVector.StructType(true, + new HostColumnVector.BasicType(true, DType.STRING), + listType, structType); + try (ColumnVector strCol = ColumnVector.fromStrings("AAAAAA"); + ColumnVector listCol = ColumnVector.fromLists(listType, Arrays.asList(1, 2, 3, 4, 5)); + ColumnVector structCol = ColumnVector.fromStructs(structType, + new HostColumnVector.StructData(1, -1L)); + ColumnVector nestedStructCol = ColumnVector.fromStructs(nestedStructType, + new HostColumnVector.StructData(null, + Arrays.asList(1, 2, null), + new HostColumnVector.StructData(null, 10L))); + Scalar s = Scalar.structFromColumnViews(strCol, listCol, structCol, nestedStructCol)) { + assertEquals(DType.STRUCT, s.getType()); + assertTrue(s.isValid()); + ColumnView[] children = s.getChildrenFromStructScalar(); + try { + assertColumnsAreEqual(strCol, children[0]); + assertColumnsAreEqual(listCol, children[1]); + assertColumnsAreEqual(structCol, children[2]); + assertColumnsAreEqual(nestedStructCol, children[3]); + } finally { + for (ColumnView child : children) child.close(); + } + } + } }