diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index 7756d7d7ce4..adf0f317340 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -1489,4 +1489,48 @@ public static ColumnVector timestampNanoSecondsFromBoxedLongs(Long... values) { return build(DType.TIMESTAMP_NANOSECONDS, values.length, (b) -> b.appendBoxed(values)); } + /** + * Creates an empty column according to the data type. + * + * It will create all the nested columns by iterating all the children in the input + * type object 'colType'. + * + * The performance is not good, so use it carefully. We may want to move this implementation + * to the native once figuring out a way to pass the nested data type to the native. + * + * @param colType the data type of the empty column + * @return an empty ColumnVector with its children. Each children contains zero elements. + * Users should close the ColumnVector to avoid memory leak. + */ + public static ColumnVector empty(HostColumnVector.DataType colType) { + if (colType == null || colType.getType() == null) { + throw new IllegalArgumentException("The data type and its 'DType' should NOT be null."); + } + if (colType instanceof HostColumnVector.BasicType) { + // Non nested type + DType dt = colType.getType(); + return new ColumnVector(makeEmptyCudfColumn(dt.typeId.getNativeId(), dt.getScale())); + } else if (colType instanceof HostColumnVector.ListType) { + // List type + assert colType.getNumChildren() == 1 : "List type requires one child type"; + try (ColumnVector child = empty(colType.getChild(0))) { + return makeList(child); + } + } else if (colType instanceof HostColumnVector.StructType) { + // Struct type + ColumnVector[] children = new ColumnVector[colType.getNumChildren()]; + try { + for (int i = 0; i < children.length; i++) { + children[i] = empty(colType.getChild(i)); + } + return makeStruct(children); + } finally { + for (ColumnVector cv : children) { + if (cv != null) cv.close(); + } + } + } else { + throw new IllegalArgumentException("Unsupported data type: " + colType); + } + } } diff --git a/java/src/main/java/ai/rapids/cudf/Scalar.java b/java/src/main/java/ai/rapids/cudf/Scalar.java index 94aef19128a..ec20f39af27 100644 --- a/java/src/main/java/ai/rapids/cudf/Scalar.java +++ b/java/src/main/java/ai/rapids/cudf/Scalar.java @@ -87,7 +87,7 @@ public static Scalar fromNull(DType type) { case DECIMAL64: return new Scalar(type, makeDecimal64Scalar(0L, type.getScale(), false)); case LIST: - return new Scalar(type, makeListScalar(0L, false)); + throw new IllegalArgumentException("Please call 'listFromNull' to create a null list scalar."); default: throw new IllegalArgumentException("Unexpected type: " + type); } @@ -335,6 +335,21 @@ public static Scalar fromString(String value) { return new Scalar(DType.STRING, makeStringScalar(value.getBytes(StandardCharsets.UTF_8), true)); } + /** + * Creates a null scalar of list type. + * + * Having this special API because the element type is required to build an empty + * nested column as the underlying column of the list scalar. + * + * @param elementType the data type of the element in the list. + * @return a null scalar of list type + */ + public static Scalar listFromNull(HostColumnVector.DataType elementType) { + try (ColumnVector col = ColumnVector.empty(elementType)) { + return new Scalar(DType.LIST, makeListScalar(col.getNativeView(), false)); + } + } + /** * Creates a scalar of list from a ColumnView. * @@ -343,7 +358,8 @@ public static Scalar fromString(String value) { */ public static Scalar listFromColumnView(ColumnView list) { if (list == null) { - return Scalar.fromNull(DType.LIST); + throw new IllegalArgumentException("'list' should NOT be null." + + " Please call 'listFromNull' to create a null list scalar."); } return new Scalar(DType.LIST, makeListScalar(list.getNativeView(), true)); } diff --git a/java/src/main/native/src/ColumnVectorJni.cpp b/java/src/main/native/src/ColumnVectorJni.cpp index f9efba673c6..85bbdd41b4a 100644 --- a/java/src/main/native/src/ColumnVectorJni.cpp +++ b/java/src/main/native/src/ColumnVectorJni.cpp @@ -33,55 +33,6 @@ #include "cudf_jni_apis.hpp" #include "dtype_utils.hpp" -namespace cudf { -namespace jni { - /** - * @brief Creates an empty column according to the type tree specified - * in @p view - * - * An empty column contains zero elements and no validity mask. - * - * Unlike the 'cudf::make_empty_column', it takes care of the nested type by - * iterating the children columns in the @p view - * - * @param[in] view The input column view - * @return An empty column for the input column view - */ - std::unique_ptr make_empty_column(JNIEnv *env, cudf::column_view const& view) { - auto tid = view.type().id(); - if (tid == cudf::type_id::LIST) { - // List - if (view.num_children() != 2) { - throw jni_exception("List type requires two children(offset, data)."); - } - // Only needs the second child. - auto data_view = view.child(1); - // offsets: [0] - auto offsets_buffer = rmm::device_buffer(sizeof(cudf::size_type)); - device_memset_async(env, offsets_buffer, 0); - auto offsets = std::make_unique(cudf::data_type{cudf::type_id::INT32}, 1, - std::move(offsets_buffer), - rmm::device_buffer(), 0); - auto data_col = make_empty_column(env, data_view); - return cudf::make_lists_column(0, std::move(offsets), std::move(data_col), - 0, rmm::device_buffer()); - } else if (tid == cudf::type_id::STRUCT) { - // Struct - std::vector> children(view.num_children()); - std::transform(view.child_begin(), view.child_end(), children.begin(), - [env](auto const& child_v) { - return make_empty_column(env, child_v); - }); - return cudf::make_structs_column(0, std::move(children), 0, rmm::device_buffer()); - } else { - // Non nested types - return cudf::make_empty_column(view.type()); - } - } - -} // namespace jni -} // namespace cudf - extern "C" { JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_sequence(JNIEnv *env, jclass, @@ -298,10 +249,10 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromScalar(JNIEnv *env, // 'row_count' times, because 'cudf::make_column_from_scalar' does not support list // type. // (Assumes the `row_count` is not big, otherwise there would be a performance issue.) - // Checks the `row_count` because `cudf::concatenate` does not support no columns. + // Checks the `row_count` because `cudf::concatenate` does not support no rows. auto data_col = row_count > 0 ? cudf::concatenate(std::vector(row_count, s_val)) - : cudf::jni::make_empty_column(env, s_val); + : cudf::empty_like(s_val); col = cudf::make_lists_column(row_count, std::move(offsets), std::move(data_col), cudf::state_null_count(mask_state, row_count), cudf::create_null_mask(row_count, mask_state)); diff --git a/java/src/main/native/src/ScalarJni.cpp b/java/src/main/native/src/ScalarJni.cpp index 275b8b051be..95f934ff91b 100644 --- a/java/src/main/native/src/ScalarJni.cpp +++ b/java/src/main/native/src/ScalarJni.cpp @@ -461,16 +461,16 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_binaryOpSV(JNIEnv *env, jclas JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeListScalar(JNIEnv *env, jclass, jlong view_handle, jboolean is_valid) { - if (is_valid) { - JNI_NULL_CHECK(env, view_handle, - "list scalar is set to `valid` but column view is null", 0); - } + JNI_NULL_CHECK(env, view_handle, "Column view should NOT be null", 0); try { cudf::jni::auto_set_device(env); auto col_view = reinterpret_cast(view_handle); - cudf::scalar* s = is_valid ? new cudf::list_scalar(*col_view) - : new cudf::list_scalar(); + // Instead of calling the `cudf::empty_like` to create an empty column when `is_valid` + // is false, always passes the input view to the scalar, to avoid copying the column + // twice. + // Let the Java layer make sure the view is empty when `is_valid` is false. + cudf::scalar* s = new cudf::list_scalar(*col_view); s->set_valid(is_valid); return reinterpret_cast(s); } diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 0349cb93f68..4c5ee7295d9 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -1391,7 +1391,7 @@ void testFromScalar() { void testFromScalarNull() { final int rowCount = 4; for (DType.DTypeEnum typeEnum : DType.DTypeEnum.values()) { - if (typeEnum == DType.DTypeEnum.EMPTY || typeEnum == DType.DTypeEnum.STRUCT) { + if (typeEnum == DType.DTypeEnum.EMPTY || typeEnum == DType.DTypeEnum.LIST || typeEnum == DType.DTypeEnum.STRUCT) { continue; } DType dType; @@ -1428,6 +1428,50 @@ void testFromScalarNullByte() { } } + @Test + void testFromScalarNullList() { + final int rowCount = 4; + for (DType.DTypeEnum typeEnum : DType.DTypeEnum.values()) { + DType dType = typeEnum.isDecimalType() ? DType.create(typeEnum, -8): DType.create(typeEnum); + DataType hDataType; + if (DType.EMPTY.equals(dType)) { + continue; + } else if (DType.LIST.equals(dType)) { + // list of list of int32 + hDataType = new ListType(true, new BasicType(true, DType.INT32)); + } else if (DType.STRUCT.equals(dType)) { + // list of struct of int32 + hDataType = new StructType(true, new BasicType(true, DType.INT32)); + } else { + // list of non nested type + hDataType = new BasicType(true, dType); + } + try (Scalar s = Scalar.listFromNull(hDataType); + ColumnVector c = ColumnVector.fromScalar(s, rowCount); + HostColumnVector hc = c.copyToHost()) { + assertEquals(DType.LIST, c.getType()); + assertEquals(rowCount, c.getRowCount()); + assertEquals(rowCount, c.getNullCount()); + for (int i = 0; i < rowCount; ++i) { + assertTrue(hc.isNull(i)); + } + + try (ColumnView child = c.getChildColumnView(0)) { + assertEquals(dType, child.getType()); + assertEquals(0L, child.getRowCount()); + assertEquals(0L, child.getNullCount()); + if (child.getType().isNestedType()) { + try (ColumnView grandson = child.getChildColumnView(0)) { + assertEquals(DType.INT32, grandson.getType()); + assertEquals(0L, grandson.getRowCount()); + assertEquals(0L, grandson.getNullCount()); + } + } + } + } + } + } + @Test void testFromScalarListOfList() { HostColumnVector.DataType childType = new HostColumnVector.ListType(true, diff --git a/java/src/test/java/ai/rapids/cudf/ScalarTest.java b/java/src/test/java/ai/rapids/cudf/ScalarTest.java index b8141d1601a..b09850bc3d9 100644 --- a/java/src/test/java/ai/rapids/cudf/ScalarTest.java +++ b/java/src/test/java/ai/rapids/cudf/ScalarTest.java @@ -18,6 +18,12 @@ package ai.rapids.cudf; +import ai.rapids.cudf.HostColumnVector.BasicType; +import ai.rapids.cudf.HostColumnVector.DataType; +import ai.rapids.cudf.HostColumnVector.ListType; +import ai.rapids.cudf.HostColumnVector.StructData; +import ai.rapids.cudf.HostColumnVector.StructType; + import org.junit.jupiter.api.Test; import java.math.BigDecimal; @@ -56,12 +62,42 @@ public void testNull() { } else { type = DType.create(dataType); } - if (!type.isNestedType() || type.equals(DType.LIST)) { + if (!type.isNestedType()) { try (Scalar s = Scalar.fromNull(type)) { assertEquals(type, s.getType()); assertFalse(s.isValid(), "null validity for " + type); } } + + // list scalar + HostColumnVector.DataType hDataType; + if (DType.EMPTY.equals(type)) { + continue; + } else if (DType.LIST.equals(type)) { + // list of list of int32 + hDataType = new ListType(true, new BasicType(true, DType.INT32)); + } else if (DType.STRUCT.equals(type)) { + // list of struct of int32 + hDataType = new StructType(true, new BasicType(true, DType.INT32)); + } else { + // list of non nested type + hDataType = new BasicType(true, type); + } + try (Scalar s = Scalar.listFromNull(hDataType); + ColumnView listCv = s.getListAsColumnView()) { + assertFalse(s.isValid(), "null validity for " + type); + assertEquals(DType.LIST, s.getType()); + assertEquals(type, listCv.getType()); + assertEquals(0L, listCv.getRowCount()); + assertEquals(0L, listCv.getNullCount()); + if (type.isNestedType()) { + try (ColumnView child = listCv.getChildColumnView(0)) { + assertEquals(DType.INT32, child.getType()); + assertEquals(0L, child.getRowCount()); + assertEquals(0L, child.getNullCount()); + } + } + } } }