Skip to content

Commit

Permalink
Fix bitmask of the output for JNI of lists::drop_list_duplicates (#…
Browse files Browse the repository at this point in the history
…10210)

Previously, the Spark-rapids plugin only needed to call `lists::drop_list_duplicates` to create maps, which requires the input to be non-nullable. As such, the output of the JNI was just a lists column without a bitmask. When operating on a nullable input lists column, it produced incorrect results.

This PR fixes that.

Authors:
  - Nghia Truong (https://github.com/ttnghia)

Approvers:
  - Jason Lowe (https://github.com/jlowe)

URL: #10210
  • Loading branch information
ttnghia authored Feb 4, 2022
1 parent 42d86c4 commit b72c79d
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 7 deletions.
4 changes: 2 additions & 2 deletions java/src/main/native/src/ColumnVectorJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,8 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeListFromOffsets(
JNI_NULL_CHECK(env, offsets_handle, "offsets_handle is null", 0)
try {
cudf::jni::auto_set_device(env);
auto const *child_cv = reinterpret_cast<cudf::column_view const *>(child_handle);
auto const *offsets_cv = reinterpret_cast<cudf::column_view const *>(offsets_handle);
auto const child_cv = reinterpret_cast<cudf::column_view const *>(child_handle);
auto const offsets_cv = reinterpret_cast<cudf::column_view const *>(offsets_handle);
CUDF_EXPECTS(offsets_cv->type().id() == cudf::type_id::INT32,
"Input offsets does not have type INT32.");

Expand Down
5 changes: 3 additions & 2 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_dropListDuplicatesWithKey
JNI_NULL_CHECK(env, keys_vals_handle, "keys_vals_handle is null", 0);
try {
cudf::jni::auto_set_device(env);
auto const *input_cv = reinterpret_cast<cudf::column_view const *>(keys_vals_handle);
auto const input_cv = reinterpret_cast<cudf::column_view const *>(keys_vals_handle);
CUDF_EXPECTS(input_cv->offset() == 0, "Input column has non-zero offset.");
CUDF_EXPECTS(input_cv->type().id() == cudf::type_id::LIST,
"Input column is not a lists column.");
Expand Down Expand Up @@ -460,7 +460,8 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_dropListDuplicatesWithKey
auto out_structs =
cudf::make_structs_column(out_child_size, std::move(out_structs_members), 0, {});
return release_as_jlong(cudf::make_lists_column(input_cv->size(), std::move(out_offsets),
std::move(out_structs), 0, {}));
std::move(out_structs), input_cv->null_count(),
cudf::copy_bitmask(*input_cv)));
}
CATCH_STD(env, 0);
}
Expand Down
63 changes: 60 additions & 3 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4380,12 +4380,14 @@ void testDropListDuplicatesWithKeysValues() {
3, 4, 5, // list2
null, 0, 6, 6, 0, // list3
null, 6, 7, null, 7 // list 4
// list5 (empty)
);
ColumnVector inputChildVals = ColumnVector.fromBoxedInts(
10, 20, // list1
30, 40, 50, // list2
60, 70, 80, 90, 100, // list3
110, 120, 130, 140, 150 // list4
// list5 (empty)
);
ColumnVector inputStructsKeysVals = ColumnVector.makeStruct(inputChildKeys, inputChildVals);
ColumnVector inputOffsets = ColumnVector.fromInts(0, 2, 5, 10, 15, 15);
Expand All @@ -4402,7 +4404,8 @@ void testDropListDuplicatesWithKeysValues() {
10, 20,
30, 40, 50,
100, 90, 60,
120, 150, 140);
120, 150, 140
);
ColumnVector expectedStructsKeysVals = ColumnVector.makeStruct(expectedChildKeys,
expectedChildVals);
ColumnVector expectedOffsets = ColumnVector.fromInts(0, 2, 5, 8, 11, 11);
Expand All @@ -4416,6 +4419,60 @@ void testDropListDuplicatesWithKeysValues() {
}
}

/**
 * Exercises dropListDuplicatesWithKeysValues on a lists-of-structs column that carries a
 * validity mask. The template column has nulls at rows 2 and 6, and both the input and the
 * expected lists columns are passed through mergeAndSetValidity with that same template, so
 * the comparison is meant to catch an output that lacks the mask.
 */
@Test
void testDropListDuplicatesWithKeysValuesNullable() {
  final int rowCount = 6;
  try (ColumnVector keysIn = ColumnVector.fromBoxedInts(
           1, 2,                  // row 1
                                  // row 2 (null)
           3, 4, 5,               // row 3
           null, 0, 6, 6, 0,      // row 4
           null, 6, 7, null, 7    // row 5
                                  // row 6 (null)
       );
       ColumnVector valsIn = ColumnVector.fromBoxedInts(
           10, 20,                    // row 1
                                      // row 2 (null)
           30, 40, 50,                // row 3
           60, 70, 80, 90, 100,       // row 4
           110, 120, 130, 140, 150    // row 5
                                      // row 6 (null)
       );
       ColumnVector structsIn = ColumnVector.makeStruct(keysIn, valsIn);
       ColumnVector offsetsIn = ColumnVector.fromInts(0, 2, 2, 5, 10, 15, 15);
       ColumnVector unmaskedListsIn = structsIn.makeListFromOffsets(rowCount, offsetsIn);
       // Only the null positions of this column matter: it is used as the validity template.
       ColumnVector validityTemplate = ColumnVector.fromBoxedInts(1, null, 1, 1, 1, null);
       ColumnVector listsIn =
           unmaskedListsIn.mergeAndSetValidity(BinaryOp.BITWISE_AND, validityTemplate);

       ColumnVector keysWanted = ColumnVector.fromBoxedInts(
           1, 2,          // row 1
                          // row 2 (null)
           3, 4, 5,       // row 3
           0, 6, null,    // row 4
           6, 7, null     // row 5
                          // row 6 (null)
       );
       ColumnVector valsWanted = ColumnVector.fromBoxedInts(
           10, 20,         // row 1
                           // row 2 (null)
           30, 40, 50,     // row 3
           100, 90, 60,    // row 4
           120, 150, 140   // row 5
                           // row 6 (null)
       );
       ColumnVector structsWanted = ColumnVector.makeStruct(keysWanted, valsWanted);
       ColumnVector offsetsWanted = ColumnVector.fromInts(0, 2, 2, 5, 8, 11, 11);
       ColumnVector unmaskedListsWanted =
           structsWanted.makeListFromOffsets(rowCount, offsetsWanted);
       ColumnVector listsWanted =
           unmaskedListsWanted.mergeAndSetValidity(BinaryOp.BITWISE_AND, validityTemplate);

       ColumnVector deduped = listsIn.dropListDuplicatesWithKeysValues();
       // NOTE(review): presumably sorted because the order of entries within each output
       // list is not guaranteed — confirm against the listSortRows/dedup contracts.
       ColumnVector dedupedSorted = deduped.listSortRows(false, false)) {
    assertColumnsAreEqual(listsWanted, dedupedSorted);
  }
}

@SafeVarargs
private static <T> ColumnVector makeListsColumn(DType childDType, List<T>... rows) {
HostColumnVector.DataType childType = new HostColumnVector.BasicType(true, childDType);
Expand Down Expand Up @@ -4716,7 +4773,7 @@ void testStringSplit() {
Table resultSplitOnce = v.stringSplit(pattern, 1);
Table resultSplitAll = v.stringSplit(pattern)) {
assertTablesAreEqual(expectedSplitOnce, resultSplitOnce);
assertTablesAreEqual(expectedSplitAll, resultSplitAll);
assertTablesAreEqual(expectedSplitAll, resultSplitAll);
}
}

Expand Down Expand Up @@ -6068,7 +6125,7 @@ void testCopyWithBooleanColumnAsValidity() {
}

// Negative case: Mismatch in row count.
Exception x = assertThrows(CudfException.class, () -> {
Exception x = assertThrows(CudfException.class, () -> {
try (ColumnVector exemplar = ColumnVector.fromBoxedInts(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
ColumnVector validity = ColumnVector.fromBoxedBooleans(F, T, F, T);
ColumnVector result = exemplar.copyWithBooleanColumnAsValidity(validity)) {
Expand Down

0 comments on commit b72c79d

Please sign in to comment.