From b72c79dea0340343a90ddb625cbe55773b52da6e Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Thu, 3 Feb 2022 19:31:43 -0700 Subject: [PATCH] Fix bitmask of the output for JNI of `lists::drop_list_duplicates` (#10210) Previously, the Spark-rapids plugin only needed to call `lists::drop_list_duplicates` to create maps, which requires the input to be non-nullable. As such, the output of the JNI was just a lists column without a bitmask. When operating on a nullable input lists column, it produced incorrect results. This PR fixes that by propagating the null count and null mask of the input lists column to the output. Authors: - Nghia Truong (https://github.com/ttnghia) Approvers: - Jason Lowe (https://github.com/jlowe) URL: https://github.com/rapidsai/cudf/pull/10210 --- java/src/main/native/src/ColumnVectorJni.cpp | 4 +- java/src/main/native/src/ColumnViewJni.cpp | 5 +- .../java/ai/rapids/cudf/ColumnVectorTest.java | 63 ++++++++++++++++++- 3 files changed, 65 insertions(+), 7 deletions(-) diff --git a/java/src/main/native/src/ColumnVectorJni.cpp b/java/src/main/native/src/ColumnVectorJni.cpp index 0e559ad0403..f01d832eb19 100644 --- a/java/src/main/native/src/ColumnVectorJni.cpp +++ b/java/src/main/native/src/ColumnVectorJni.cpp @@ -252,8 +252,8 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeListFromOffsets( JNI_NULL_CHECK(env, offsets_handle, "offsets_handle is null", 0) try { cudf::jni::auto_set_device(env); - auto const *child_cv = reinterpret_cast(child_handle); - auto const *offsets_cv = reinterpret_cast(offsets_handle); + auto const child_cv = reinterpret_cast(child_handle); + auto const offsets_cv = reinterpret_cast(offsets_handle); CUDF_EXPECTS(offsets_cv->type().id() == cudf::type_id::INT32, "Input offsets does not have type INT32."); diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index 63247eb0066..eec4a78a457 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -408,7 +408,7 @@ JNIEXPORT jlong JNICALL 
Java_ai_rapids_cudf_ColumnView_dropListDuplicatesWithKey JNI_NULL_CHECK(env, keys_vals_handle, "keys_vals_handle is null", 0); try { cudf::jni::auto_set_device(env); - auto const *input_cv = reinterpret_cast(keys_vals_handle); + auto const input_cv = reinterpret_cast(keys_vals_handle); CUDF_EXPECTS(input_cv->offset() == 0, "Input column has non-zero offset."); CUDF_EXPECTS(input_cv->type().id() == cudf::type_id::LIST, "Input column is not a lists column."); @@ -460,7 +460,8 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_dropListDuplicatesWithKey auto out_structs = cudf::make_structs_column(out_child_size, std::move(out_structs_members), 0, {}); return release_as_jlong(cudf::make_lists_column(input_cv->size(), std::move(out_offsets), - std::move(out_structs), 0, {})); + std::move(out_structs), input_cv->null_count(), + cudf::copy_bitmask(*input_cv))); } CATCH_STD(env, 0); } diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 8f39c3c51ce..f9c8029ed84 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -4380,12 +4380,14 @@ void testDropListDuplicatesWithKeysValues() { 3, 4, 5, // list2 null, 0, 6, 6, 0, // list3 null, 6, 7, null, 7 // list 4 + // list5 (empty) ); ColumnVector inputChildVals = ColumnVector.fromBoxedInts( 10, 20, // list1 30, 40, 50, // list2 60, 70, 80, 90, 100, // list3 110, 120, 130, 140, 150 // list4 + // list5 (empty) ); ColumnVector inputStructsKeysVals = ColumnVector.makeStruct(inputChildKeys, inputChildVals); ColumnVector inputOffsets = ColumnVector.fromInts(0, 2, 5, 10, 15, 15); @@ -4402,7 +4404,8 @@ void testDropListDuplicatesWithKeysValues() { 10, 20, 30, 40, 50, 100, 90, 60, - 120, 150, 140); + 120, 150, 140 + ); ColumnVector expectedStructsKeysVals = ColumnVector.makeStruct(expectedChildKeys, expectedChildVals); ColumnVector expectedOffsets = ColumnVector.fromInts(0, 2, 
5, 8, 11, 11); @@ -4416,6 +4419,60 @@ void testDropListDuplicatesWithKeysValues() { } } + @Test + void testDropListDuplicatesWithKeysValuesNullable() { + try(ColumnVector inputChildKeys = ColumnVector.fromBoxedInts( + 1, 2, // list1 + // list2 (null) + 3, 4, 5, // list3 + null, 0, 6, 6, 0, // list4 + null, 6, 7, null, 7 // list 5 + // list6 (null) + ); + ColumnVector inputChildVals = ColumnVector.fromBoxedInts( + 10, 20, // list1 + // list2 (null) + 30, 40, 50, // list3 + 60, 70, 80, 90, 100, // list4 + 110, 120, 130, 140, 150 // list5 + // list6 (null) + ); + ColumnVector inputStructsKeysVals = ColumnVector.makeStruct(inputChildKeys, inputChildVals); + ColumnVector inputOffsets = ColumnVector.fromInts(0, 2, 2, 5, 10, 15, 15); + ColumnVector tmpInputListsKeysVals = inputStructsKeysVals.makeListFromOffsets(6,inputOffsets); + ColumnVector templateBitmask = ColumnVector.fromBoxedInts(1, null, 1, 1, 1, null); + ColumnVector inputListsKeysVals = tmpInputListsKeysVals.mergeAndSetValidity(BinaryOp.BITWISE_AND, templateBitmask); + + ColumnVector expectedChildKeys = ColumnVector.fromBoxedInts( + 1, 2, // list1 + // list2 (null) + 3, 4, 5, // list3 + 0, 6, null, // list4 + 6, 7, null // list5 + // list6 (null) + ); + ColumnVector expectedChildVals = ColumnVector.fromBoxedInts( + 10, 20, // list1 + // list2 (null) + 30, 40, 50, // list3 + 100, 90, 60, // list4 + 120, 150, 140 // list5 + // list6 (null) + ); + ColumnVector expectedStructsKeysVals = ColumnVector.makeStruct(expectedChildKeys, + expectedChildVals); + ColumnVector expectedOffsets = ColumnVector.fromInts(0, 2, 2, 5, 8, 11, 11); + ColumnVector tmpExpectedListsKeysVals = expectedStructsKeysVals.makeListFromOffsets(6, + expectedOffsets); + ColumnVector expectedListsKeysVals = tmpExpectedListsKeysVals.mergeAndSetValidity(BinaryOp.BITWISE_AND, templateBitmask); + + ColumnVector output = inputListsKeysVals.dropListDuplicatesWithKeysValues(); + ColumnVector sortedOutput = output.listSortRows(false, false); + ) { + 
assertColumnsAreEqual(expectedListsKeysVals, sortedOutput); + } + } + @SafeVarargs private static ColumnVector makeListsColumn(DType childDType, List... rows) { HostColumnVector.DataType childType = new HostColumnVector.BasicType(true, childDType); @@ -4716,7 +4773,7 @@ void testStringSplit() { Table resultSplitOnce = v.stringSplit(pattern, 1); Table resultSplitAll = v.stringSplit(pattern)) { assertTablesAreEqual(expectedSplitOnce, resultSplitOnce); - assertTablesAreEqual(expectedSplitAll, resultSplitAll); + assertTablesAreEqual(expectedSplitAll, resultSplitAll); } } @@ -6068,7 +6125,7 @@ void testCopyWithBooleanColumnAsValidity() { } // Negative case: Mismatch in row count. - Exception x = assertThrows(CudfException.class, () -> { + Exception x = assertThrows(CudfException.class, () -> { try (ColumnVector exemplar = ColumnVector.fromBoxedInts(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); ColumnVector validity = ColumnVector.fromBoxedBooleans(F, T, F, T); ColumnVector result = exemplar.copyWithBooleanColumnAsValidity(validity)) {