From 04d6e5aef09d409351f8485121d5d54d6abbceca Mon Sep 17 00:00:00 2001 From: Liangcai Li Date: Fri, 30 Apr 2021 03:24:16 +0800 Subject: [PATCH] JNI support for scalar of list (#8077) This PR is to add the JNI support for scalar of list, along with building a ColumnVector from a list scalar. Since the PR https://github.com/rapidsai/cudf/pull/7584 inroduced the list scalar in cpp. Signed-off-by: Firestarman Authors: - Liangcai Li (https://github.com/firestarman) Approvers: - Jason Lowe (https://github.com/jlowe) URL: https://github.com/rapidsai/cudf/pull/8077 --- .../main/java/ai/rapids/cudf/ColumnView.java | 11 ++- java/src/main/java/ai/rapids/cudf/Scalar.java | 46 +++++++++++ java/src/main/native/src/ColumnVectorJni.cpp | 79 ++++++++++++++++++- java/src/main/native/src/CudaJni.cpp | 9 ++- java/src/main/native/src/ScalarJni.cpp | 34 +++++++- java/src/main/native/src/cudf_jni_apis.hpp | 7 ++ .../java/ai/rapids/cudf/ColumnVectorTest.java | 76 ++++++++++++++++-- .../test/java/ai/rapids/cudf/ScalarTest.java | 33 +++++++- 8 files changed, 283 insertions(+), 12 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 402c64dd83d..f6f4891cdcb 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -41,7 +41,7 @@ public class ColumnView implements AutoCloseable, BinaryOperable { * Constructs a Column View given a native view address * @param address the view handle */ - protected ColumnView(long address) { + ColumnView(long address) { this.viewHandle = address; this.type = DType.fromNative(ColumnView.getNativeTypeId(viewHandle), ColumnView.getNativeTypeScale(viewHandle)); this.rows = ColumnView.getNativeRowCount(viewHandle); @@ -211,6 +211,15 @@ public void close() { viewHandle = 0; } + @Override + public String toString() { + return "ColumnView{" + + "rows=" + rows + + ", type=" + type + + ", nullCount=" + nullCount + + '}'; + } + /** * Used for string strip function. * Indicates characters to be stripped from the beginning, end, or both of each string. diff --git a/java/src/main/java/ai/rapids/cudf/Scalar.java b/java/src/main/java/ai/rapids/cudf/Scalar.java index 4221b394826..94aef19128a 100644 --- a/java/src/main/java/ai/rapids/cudf/Scalar.java +++ b/java/src/main/java/ai/rapids/cudf/Scalar.java @@ -86,6 +86,8 @@ public static Scalar fromNull(DType type) { return new Scalar(type, makeDecimal32Scalar(0, type.getScale(), false)); case DECIMAL64: return new Scalar(type, makeDecimal64Scalar(0L, type.getScale(), false)); + case LIST: + return new Scalar(type, makeListScalar(0L, false)); default: throw new IllegalArgumentException("Unexpected type: " + type); } @@ -333,6 +335,19 @@ public static Scalar fromString(String value) { return new Scalar(DType.STRING, makeStringScalar(value.getBytes(StandardCharsets.UTF_8), true)); } + /** + * Creates a scalar of list from a ColumnView. + * + * All the rows in the ColumnView will be copied into the Scalar. So the ColumnView + * can be closed after this call completes. + */ + public static Scalar listFromColumnView(ColumnView list) { + if (list == null) { + return Scalar.fromNull(DType.LIST); + } + return new Scalar(DType.LIST, makeListScalar(list.getNativeView(), true)); + } + private static native void closeScalar(long scalarHandle); private static native boolean isScalarValid(long scalarHandle); private static native byte getByte(long scalarHandle); @@ -342,6 +357,7 @@ public static Scalar fromString(String value) { private static native float getFloat(long scalarHandle); private static native double getDouble(long scalarHandle); private static native byte[] getUTF8(long scalarHandle); + private static native long getListAsColumnView(long scalarHandle); private static native long makeBool8Scalar(boolean isValid, boolean value); private static native long makeInt8Scalar(byte value, boolean isValid); private static native long makeUint8Scalar(byte value, boolean isValid); @@ -360,6 +376,7 @@ public static Scalar fromString(String value) { private static native long makeTimestampTimeScalar(int dtypeNativeId, long value, boolean isValid); private static native long makeDecimal32Scalar(int value, int scale, boolean isValid); private static native long makeDecimal64Scalar(long value, int scale, boolean isValid); + private static native long makeListScalar(long viewHandle, boolean isValid); Scalar(DType type, long scalarHandle) { @@ -484,6 +501,19 @@ public byte[] getUTF8() { return getUTF8(getScalarHandle()); } + /** + * Returns the scalar value as a ColumnView. Callers should close the returned ColumnView to + * avoid memory leak. + * + * The returned ColumnView is only valid as long as the Scalar remains valid. If the Scalar + * is closed before this ColumnView is closed, using this ColumnView will result in undefined + * behavior. + */ + public ColumnView getListAsColumnView() { + assert DType.LIST.equals(type) : "Cannot get list for the vector of type " + type; + return new ColumnView(getListAsColumnView(getScalarHandle())); + } + @Override public ColumnVector binaryOp(BinaryOp op, BinaryOperable rhs, DType outType) { if (rhs instanceof ColumnView) { @@ -541,6 +571,11 @@ public boolean equals(Object o) { return getLong() == other.getLong(); case STRING: return Arrays.equals(getUTF8(), other.getUTF8()); + case LIST: + try (ColumnView viewMe = getListAsColumnView(); + ColumnView viewO = other.getListAsColumnView()) { + return viewMe.equals(viewO); + } default: throw new IllegalStateException("Unexpected type: " + type); } @@ -589,6 +624,11 @@ public int hashCode() { case STRING: valueHash = Arrays.hashCode(getUTF8()); break; + case LIST: + try (ColumnView v = getListAsColumnView()) { + valueHash = v.hashCode(); + } + break; default: throw new IllegalStateException("Unknown scalar type: " + type); } @@ -651,6 +691,12 @@ public String toString() { case DECIMAL64: sb.append(getBigDecimal()); break; + case LIST: + try (ColumnView v = getListAsColumnView()) { + // It's not easy to pull out the elements so just a simple string of some metadata. + sb.append(v.toString()); + } + break; default: throw new IllegalArgumentException("Unknown scalar type: " + type); } diff --git a/java/src/main/native/src/ColumnVectorJni.cpp b/java/src/main/native/src/ColumnVectorJni.cpp index ba0e4f05714..858dcf6fd5d 100644 --- a/java/src/main/native/src/ColumnVectorJni.cpp +++ b/java/src/main/native/src/ColumnVectorJni.cpp @@ -31,6 +31,55 @@ #include "cudf_jni_apis.hpp" #include "dtype_utils.hpp" +namespace cudf { +namespace jni { + /** + * @brief Creates an empty column according to the type tree specified + * in @p view + * + * An empty column contains zero elements and no validity mask. + * + * Unlike the 'cudf::make_empty_column', it takes care of the nested type by + * iterating the children columns in the @p view + * + * @param[in] view The input column view + * @return An empty column for the input column view + */ + std::unique_ptr make_empty_column(JNIEnv *env, cudf::column_view const& view) { + auto tid = view.type().id(); + if (tid == cudf::type_id::LIST) { + // List + if (view.num_children() != 2) { + throw jni_exception("List type requires two children(offset, data)."); + } + // Only needs the second child. + auto data_view = view.child(1); + // offsets: [0] + auto offsets_buffer = rmm::device_buffer(sizeof(cudf::size_type)); + device_memset_async(env, offsets_buffer, 0); + auto offsets = std::make_unique(cudf::data_type{cudf::type_id::INT32}, 1, + std::move(offsets_buffer), + rmm::device_buffer(), 0); + auto data_col = make_empty_column(env, data_view); + return cudf::make_lists_column(0, std::move(offsets), std::move(data_col), + 0, rmm::device_buffer()); + } else if (tid == cudf::type_id::STRUCT) { + // Struct + std::vector> children(view.num_children()); + std::transform(view.child_begin(), view.child_end(), children.begin(), + [env](auto const& child_v) { + return make_empty_column(env, child_v); + }); + return cudf::make_structs_column(0, std::move(children), 0, rmm::device_buffer()); + } else { + // Non nested types + return cudf::make_empty_column(view.type()); + } + } + +} // namespace jni +} // namespace cudf + extern "C" { JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_sequence(JNIEnv *env, jclass, @@ -168,6 +217,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeList(JNIEnv *env, j JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromScalar(JNIEnv *env, jclass, jlong j_scalar, jint row_count) { + using ScalarType = cudf::scalar_type_t; JNI_NULL_CHECK(env, j_scalar, "scalar is null", 0); try { cudf::jni::auto_set_device(env); @@ -176,7 +226,34 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromScalar(JNIEnv *env, cudf::mask_state mask_state = scalar_val->is_valid() ? cudf::mask_state::UNALLOCATED : cudf::mask_state::ALL_NULL; std::unique_ptr col; - if (row_count == 0) { + if (dtype.id() == cudf::type_id::LIST) { + // Neither 'cudf::make_empty_column' nor 'cudf::make_column_from_scalar' supports + // LIST type for now (https://github.com/rapidsai/cudf/issues/8088), so the list + // precedes the others and takes care of the empty column itself. + auto s_list = reinterpret_cast(scalar_val); + cudf::column_view s_val = s_list->view(); + + // Offsets: [0, list_size, list_size*2, ..., list_szie*row_count] + auto zero = cudf::make_numeric_scalar(cudf::data_type(cudf::type_id::INT32)); + auto step = cudf::make_numeric_scalar(cudf::data_type(cudf::type_id::INT32)); + zero->set_valid(true); + step->set_valid(true); + static_cast(zero.get())->set_value(0); + static_cast(step.get())->set_value(s_val.size()); + std::unique_ptr offsets = cudf::sequence(row_count + 1, *zero, *step); + // Data: + // Builds the data column by leveraging `cudf::concatenate` to repeat the 's_val' + // 'row_count' times, because 'cudf::make_column_from_scalar' does not support list + // type. + // (Assumes the `row_count` is not big, otherwise there would be a performance issue.) + // Checks the `row_count` because `cudf::concatenate` does not support no columns. + auto data_col = row_count > 0 + ? cudf::concatenate(std::vector(row_count, s_val)) + : cudf::jni::make_empty_column(env, s_val); + col = cudf::make_lists_column(row_count, std::move(offsets), std::move(data_col), + cudf::state_null_count(mask_state, row_count), + cudf::create_null_mask(row_count, mask_state)); + } else if (row_count == 0) { col = cudf::make_empty_column(dtype); } else if (cudf::is_fixed_width(dtype)) { col = cudf::make_fixed_width_column(dtype, row_count, mask_state); diff --git a/java/src/main/native/src/CudaJni.cpp b/java/src/main/native/src/CudaJni.cpp index b41fae21a74..f5eb09fa2d4 100644 --- a/java/src/main/native/src/CudaJni.cpp +++ b/java/src/main/native/src/CudaJni.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include "jni_utils.hpp" namespace { @@ -47,6 +48,12 @@ void auto_set_device(JNIEnv *env) { } } +/** Fills all the bytes in the buffer 'buf' with 'value'. */ +void device_memset_async(JNIEnv *env, rmm::device_buffer& buf, char value) { + cudaError_t cuda_status = cudaMemsetAsync((void *)buf.data(), value, buf.size()); + jni_cuda_check(env, cuda_status); +} + } // namespace jni } // namespace cudf diff --git a/java/src/main/native/src/ScalarJni.cpp b/java/src/main/native/src/ScalarJni.cpp index 4e74cab9328..275b8b051be 100644 --- a/java/src/main/native/src/ScalarJni.cpp +++ b/java/src/main/native/src/ScalarJni.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -125,6 +125,19 @@ JNIEXPORT jbyteArray JNICALL Java_ai_rapids_cudf_Scalar_getUTF8(JNIEnv *env, jcl CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_getListAsColumnView(JNIEnv *env, jclass, + jlong scalar_handle) { + JNI_NULL_CHECK(env, scalar_handle, "scalar handle is null", 0); + try { + cudf::jni::auto_set_device(env); + auto s = reinterpret_cast(scalar_handle); + // Creates a column view in heap with the stack one, to let JVM take care of its + // life cycle. + return reinterpret_cast(new cudf::column_view(s->view())); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeBool8Scalar(JNIEnv *env, jclass, jboolean value, jboolean is_valid) { @@ -445,4 +458,23 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_binaryOpSV(JNIEnv *env, jclas CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeListScalar(JNIEnv *env, jclass, + jlong view_handle, + jboolean is_valid) { + if (is_valid) { + JNI_NULL_CHECK(env, view_handle, + "list scalar is set to `valid` but column view is null", 0); + } + try { + cudf::jni::auto_set_device(env); + auto col_view = reinterpret_cast(view_handle); + + cudf::scalar* s = is_valid ? new cudf::list_scalar(*col_view) + : new cudf::list_scalar(); + s->set_valid(is_valid); + return reinterpret_cast(s); + } + CATCH_STD(env, 0); +} + } // extern "C" diff --git a/java/src/main/native/src/cudf_jni_apis.hpp b/java/src/main/native/src/cudf_jni_apis.hpp index 76c7e91d335..14999156890 100644 --- a/java/src/main/native/src/cudf_jni_apis.hpp +++ b/java/src/main/native/src/cudf_jni_apis.hpp @@ -70,5 +70,12 @@ void set_cudf_device(int device); */ void auto_set_device(JNIEnv *env); +/** + * Fills all the bytes in the buffer 'buf' with 'value'. + * The operation has not necessarily completed when this returns, but it could overlap with + * operations occurring on other streams. + */ +void device_memset_async(JNIEnv *env, rmm::device_buffer& buf, char value); + } // namespace jni } // namespace cudf diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 36123704ae6..cca7090b8c7 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -1173,10 +1173,14 @@ void testFromScalarZeroRows() { case DURATION_NANOSECONDS: s = Scalar.durationFromLong(DType.create(type), 21313); break; - case EMPTY: - case LIST: - case STRUCT: - continue; + case EMPTY: + case STRUCT: + continue; + case LIST: + try (ColumnVector list = ColumnVector.fromInts(1, 2, 3)) { + s = Scalar.listFromColumnView(list); + } + break; default: throw new IllegalArgumentException("Unexpected type: " + type); } @@ -1349,9 +1353,20 @@ void testFromScalar() { break; } case EMPTY: - case LIST: case STRUCT: continue; + case LIST: + try (ColumnVector list = ColumnVector.fromInts(1, 2, 3)) { + s = Scalar.listFromColumnView(list); + expected = ColumnVector.fromLists( + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(1, 2, 3), + Arrays.asList(1, 2, 3), + Arrays.asList(1, 2, 3), + Arrays.asList(1, 2, 3)); + } + break; default: throw new IllegalArgumentException("Unexpected type: " + type); } @@ -1376,7 +1391,7 @@ void testFromScalar() { void testFromScalarNull() { final int rowCount = 4; for (DType.DTypeEnum typeEnum : DType.DTypeEnum.values()) { - if (typeEnum == DType.DTypeEnum.EMPTY || typeEnum == DType.DTypeEnum.LIST || typeEnum == DType.DTypeEnum.STRUCT) { + if (typeEnum == DType.DTypeEnum.EMPTY || typeEnum == DType.DTypeEnum.STRUCT) { continue; } DType dType; @@ -1413,6 +1428,55 @@ void testFromScalarNullByte() { } } + @Test + void testFromScalarListOfList() { + HostColumnVector.DataType childType = new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)); + HostColumnVector.DataType resultType = new HostColumnVector.ListType(true, childType); + try (ColumnVector list = ColumnVector.fromLists(childType, + Arrays.asList(1, 2, 3), + Arrays.asList(4, 5, 6)); + Scalar s = Scalar.listFromColumnView(list)) { + try (ColumnVector ret = ColumnVector.fromScalar(s, 2); + ColumnVector expected = ColumnVector.fromLists(resultType, + Arrays.asList(Arrays.asList(1, 2, 3),Arrays.asList(4, 5, 6)), + Arrays.asList(Arrays.asList(1, 2, 3),Arrays.asList(4, 5, 6)))) { + assertColumnsAreEqual(expected, ret); + } + // empty row + try (ColumnVector ret = ColumnVector.fromScalar(s, 0)) { + assertEquals(ret.getRowCount(), 0); + assertEquals(ret.getNullCount(), 0); + } + } + } + + @Test + void testFromScalarListOfStruct() { + HostColumnVector.DataType childType = new HostColumnVector.StructType(true, + new HostColumnVector.BasicType(true, DType.INT32), + new HostColumnVector.BasicType(true, DType.STRING)); + HostColumnVector.DataType resultType = new HostColumnVector.ListType(true, childType); + try (ColumnVector list = ColumnVector.fromStructs(childType, + new HostColumnVector.StructData(1, "s1"), + new HostColumnVector.StructData(2, "s2")); + Scalar s = Scalar.listFromColumnView(list)) { + try (ColumnVector ret = ColumnVector.fromScalar(s, 2); + ColumnVector expected = ColumnVector.fromLists(resultType, + Arrays.asList(new HostColumnVector.StructData(1, "s1"), + new HostColumnVector.StructData(2, "s2")), + Arrays.asList(new HostColumnVector.StructData(1, "s1"), + new HostColumnVector.StructData(2, "s2")))) { + assertColumnsAreEqual(expected, ret); + } + // empty row + try (ColumnVector ret = ColumnVector.fromScalar(s, 0)) { + assertEquals(ret.getRowCount(), 0); + assertEquals(ret.getNullCount(), 0); + } + } + } + @Test void testReplaceNullsScalarEmptyColumn() { try (ColumnVector input = ColumnVector.fromBoxedBooleans(); diff --git a/java/src/test/java/ai/rapids/cudf/ScalarTest.java b/java/src/test/java/ai/rapids/cudf/ScalarTest.java index 627171e4b2f..b8141d1601a 100644 --- a/java/src/test/java/ai/rapids/cudf/ScalarTest.java +++ b/java/src/test/java/ai/rapids/cudf/ScalarTest.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2019, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,9 @@ import org.junit.jupiter.api.Test; import java.math.BigDecimal; +import java.util.Arrays; +import static ai.rapids.cudf.TableTest.assertColumnsAreEqual; import static org.junit.jupiter.api.Assertions.*; public class ScalarTest extends CudfTestBase { @@ -54,7 +56,7 @@ public void testNull() { } else { type = DType.create(dataType); } - if (!type.isNestedType()) { + if (!type.isNestedType() || type.equals(DType.LIST)) { try (Scalar s = Scalar.fromNull(type)) { assertEquals(type, s.getType()); assertFalse(s.isValid(), "null validity for " + type); @@ -205,4 +207,31 @@ public void testString() { assertArrayEquals(new byte[]{'T', 'E', 'S', 'T'}, s.getUTF8()); } } + + @Test + public void testList() { + // list of int + try (ColumnVector listInt = ColumnVector.fromInts(1, 2, 3, 4); + Scalar s = Scalar.listFromColumnView(listInt)) { + assertEquals(DType.LIST, s.getType()); + assertTrue(s.isValid()); + try (ColumnView v = s.getListAsColumnView()) { + assertColumnsAreEqual(listInt, v); + } + } + + // list of list + HostColumnVector.DataType listDT = new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)); + try (ColumnVector listList = ColumnVector.fromLists(listDT, + Arrays.asList(1, 2, 3), + Arrays.asList(4, 5, 6)); + Scalar s = Scalar.listFromColumnView(listList)) { + assertEquals(DType.LIST, s.getType()); + assertTrue(s.isValid()); + try (ColumnView v = s.getListAsColumnView()) { + assertColumnsAreEqual(listList, v); + } + } + } }