Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bitmask of the output for JNI of lists::drop_list_duplicates #10210

Merged
17 changes: 16 additions & 1 deletion java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ public ColumnVector makeListFromOffsets(long rows, ColumnView offsets) {
return new ColumnVector(makeListFromOffsets(getNativeView(), offsets.getNativeView(), rows));
}

/**
/**
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
* Create a new vector of length rows, starting at the initialValue and going by step each time.
* Only numeric types are supported.
* @param initialValue the initial value to start at.
Expand Down Expand Up @@ -827,6 +827,19 @@ public ColumnVector castTo(DType type) {
return super.castTo(type);
}

/**
* Create a new column that has data copied from the current column and a new bitmask copied from
* the given templateBitmask column.
*
* The caller is responsible for any data corruption that is make by calling to this function.
*
* @param templateBitmask a column from which the bitmask will be copied to the output column.
*/
public ColumnVector replaceBitmask(ColumnView templateBitmask) {
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
assert templateBitmask.getRowCount() == getRowCount();
return new ColumnVector(replaceBitmask(getNativeView(), templateBitmask.getNativeView()));
}

/////////////////////////////////////////////////////////////////////////////
// NATIVE METHODS
/////////////////////////////////////////////////////////////////////////////
Expand All @@ -848,6 +861,8 @@ private static native long makeList(long[] handles, long typeHandle, int scale,
private static native long makeListFromOffsets(long childHandle, long offsetsHandle, long rows)
throws CudfException;

private static native long replaceBitmask(long nativeHandle, long templateBitmaskHandle);

private static native long concatenate(long[] viewHandles) throws CudfException;

/**
Expand Down
22 changes: 20 additions & 2 deletions java/src/main/native/src/ColumnVectorJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,8 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeListFromOffsets(
JNI_NULL_CHECK(env, offsets_handle, "offsets_handle is null", 0)
try {
cudf::jni::auto_set_device(env);
auto const *child_cv = reinterpret_cast<cudf::column_view const *>(child_handle);
auto const *offsets_cv = reinterpret_cast<cudf::column_view const *>(offsets_handle);
auto const child_cv = reinterpret_cast<cudf::column_view const *>(child_handle);
auto const offsets_cv = reinterpret_cast<cudf::column_view const *>(offsets_handle);
CUDF_EXPECTS(offsets_cv->type().id() == cudf::type_id::INT32,
"Input offsets does not have type INT32.");

Expand All @@ -264,6 +264,24 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeListFromOffsets(
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_replaceBitmask(
JNIEnv *env, jobject j_object, jlong native_handle, jlong template_bitmask_handle) {
JNI_NULL_CHECK(env, native_handle, "native_handle is null", 0)
JNI_NULL_CHECK(env, template_bitmask_handle, "template_bitmask_handle is null", 0)
try {
cudf::jni::auto_set_device(env);
auto const input_cv = reinterpret_cast<cudf::column_view const *>(native_handle);
auto const template_bitmask_cv =
reinterpret_cast<cudf::column_view const *>(template_bitmask_handle);

auto result = std::make_unique<cudf::column>(*input_cv);
result->set_null_mask(cudf::copy_bitmask(*template_bitmask_cv),
template_bitmask_cv->null_count());
return release_as_jlong(result);
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromScalar(JNIEnv *env, jclass,
jlong j_scalar,
jint row_count) {
Expand Down
5 changes: 3 additions & 2 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_dropListDuplicatesWithKey
JNI_NULL_CHECK(env, keys_vals_handle, "keys_vals_handle is null", 0);
try {
cudf::jni::auto_set_device(env);
auto const *input_cv = reinterpret_cast<cudf::column_view const *>(keys_vals_handle);
auto const input_cv = reinterpret_cast<cudf::column_view const *>(keys_vals_handle);
CUDF_EXPECTS(input_cv->offset() == 0, "Input column has non-zero offset.");
CUDF_EXPECTS(input_cv->type().id() == cudf::type_id::LIST,
"Input column is not a lists column.");
Expand Down Expand Up @@ -460,7 +460,8 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_dropListDuplicatesWithKey
auto out_structs =
cudf::make_structs_column(out_child_size, std::move(out_structs_members), 0, {});
return release_as_jlong(cudf::make_lists_column(input_cv->size(), std::move(out_offsets),
std::move(out_structs), 0, {}));
std::move(out_structs), input_cv->null_count(),
cudf::copy_bitmask(*input_cv)));
}
CATCH_STD(env, 0);
}
Expand Down
60 changes: 58 additions & 2 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4380,12 +4380,14 @@ void testDropListDuplicatesWithKeysValues() {
3, 4, 5, // list2
null, 0, 6, 6, 0, // list3
null, 6, 7, null, 7 // list 4
// list5 (empty)
);
ColumnVector inputChildVals = ColumnVector.fromBoxedInts(
10, 20, // list1
30, 40, 50, // list2
60, 70, 80, 90, 100, // list3
110, 120, 130, 140, 150 // list4
// list5 (empty)
);
ColumnVector inputStructsKeysVals = ColumnVector.makeStruct(inputChildKeys, inputChildVals);
ColumnVector inputOffsets = ColumnVector.fromInts(0, 2, 5, 10, 15, 15);
Expand Down Expand Up @@ -4416,6 +4418,60 @@ void testDropListDuplicatesWithKeysValues() {
}
}

@Test
void testDropListDuplicatesWithKeysValuesNullable() {
try(ColumnVector inputChildKeys = ColumnVector.fromBoxedInts(
1, 2, // list1
// list2 (null)
3, 4, 5, // list3
null, 0, 6, 6, 0, // list4
null, 6, 7, null, 7 // list 5
// list6 (null)
);
ColumnVector inputChildVals = ColumnVector.fromBoxedInts(
10, 20, // list1
// list2 (null)
30, 40, 50, // list3
60, 70, 80, 90, 100, // list4
110, 120, 130, 140, 150 // list5
// list6 (null)
);
ColumnVector inputStructsKeysVals = ColumnVector.makeStruct(inputChildKeys, inputChildVals);
ColumnVector inputOffsets = ColumnVector.fromInts(0, 2, 2, 5, 10, 15, 15);
ColumnVector tmpInputListsKeysVals = inputStructsKeysVals.makeListFromOffsets(6,inputOffsets);
ColumnVector templateBitmask = ColumnVector.fromBoxedInts(1, null, 1, 1, 1, null);
ColumnVector inputListsKeysVals = tmpInputListsKeysVals.replaceBitmask(templateBitmask);

ColumnVector expectedChildKeys = ColumnVector.fromBoxedInts(
1, 2, // list1
// list2 (null)
3, 4, 5, // list3
0, 6, null, // list4
6, 7, null // list5
// list6 (null)
);
ColumnVector expectedChildVals = ColumnVector.fromBoxedInts(
10, 20, // list1
// list2 (null)
30, 40, 50, // list3
100, 90, 60, // list4
120, 150, 140 // list5
// list6 (null)
);
ColumnVector expectedStructsKeysVals = ColumnVector.makeStruct(expectedChildKeys,
expectedChildVals);
ColumnVector expectedOffsets = ColumnVector.fromInts(0, 2, 2, 5, 8, 11, 11);
ColumnVector tmpExpectedListsKeysVals = expectedStructsKeysVals.makeListFromOffsets(6,
expectedOffsets);
ColumnVector expectedListsKeysVals = tmpExpectedListsKeysVals.replaceBitmask(templateBitmask);

ColumnVector output = inputListsKeysVals.dropListDuplicatesWithKeysValues();
ColumnVector sortedOutput = output.listSortRows(false, false);
) {
assertColumnsAreEqual(expectedListsKeysVals, sortedOutput);
}
}

@SafeVarargs
private static <T> ColumnVector makeListsColumn(DType childDType, List<T>... rows) {
HostColumnVector.DataType childType = new HostColumnVector.BasicType(true, childDType);
Expand Down Expand Up @@ -4716,7 +4772,7 @@ void testStringSplit() {
Table resultSplitOnce = v.stringSplit(pattern, 1);
Table resultSplitAll = v.stringSplit(pattern)) {
assertTablesAreEqual(expectedSplitOnce, resultSplitOnce);
assertTablesAreEqual(expectedSplitAll, resultSplitAll);
assertTablesAreEqual(expectedSplitAll, resultSplitAll);
}
}

Expand Down Expand Up @@ -6068,7 +6124,7 @@ void testCopyWithBooleanColumnAsValidity() {
}

// Negative case: Mismatch in row count.
Exception x = assertThrows(CudfException.class, () -> {
Exception x = assertThrows(CudfException.class, () -> {
try (ColumnVector exemplar = ColumnVector.fromBoxedInts(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
ColumnVector validity = ColumnVector.fromBoxedBooleans(F, T, F, T);
ColumnVector result = exemplar.copyWithBooleanColumnAsValidity(validity)) {
Expand Down