diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index b29b873092d..90fe3553abc 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -388,7 +388,19 @@ public final ColumnVector findAndReplaceAll(ColumnView oldValues, ColumnView new * @return - ColumnVector with nulls replaced by scalar */ public final ColumnVector replaceNulls(Scalar scalar) { - return new ColumnVector(replaceNulls(getNativeView(), scalar.getScalarHandle())); + return new ColumnVector(replaceNullsScalar(getNativeView(), scalar.getScalarHandle())); + } + + /** + * Returns a ColumnVector with any null values replaced with the corresponding row in the + * specified replacement column. + * This column and the replacement column must have the same type and number of rows. + * + * @param replacements column of replacement values + * @return column with nulls replaced by corresponding row of replacements column + */ + public final ColumnVector replaceNulls(ColumnView replacements) { + return new ColumnVector(replaceNullsColumn(getNativeView(), replacements.getNativeView())); } /** @@ -2840,7 +2852,9 @@ private static native long rollingWindow( private static native long charLengths(long viewHandle) throws CudfException; - private static native long replaceNulls(long viewHandle, long scalarHandle) throws CudfException; + private static native long replaceNullsScalar(long viewHandle, long scalarHandle) throws CudfException; + + private static native long replaceNullsColumn(long viewHandle, long replaceViewHandle) throws CudfException; private static native long ifElseVV(long predVec, long trueVec, long falseVec) throws CudfException; diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index 3928794b55c..dc1acc50b5f 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -121,8 +121,9 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_lowerStrings(JNIEnv *env, CATCH_STD(env, 0); } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceNulls(JNIEnv *env, jclass, - jlong j_col, jlong j_scalar) { +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceNullsScalar(JNIEnv *env, jclass, + jlong j_col, + jlong j_scalar) { JNI_NULL_CHECK(env, j_col, "column is null", 0); JNI_NULL_CHECK(env, j_scalar, "scalar is null", 0); try { @@ -135,6 +136,21 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceNulls(JNIEnv *env, CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceNullsColumn(JNIEnv *env, jclass, + jlong j_col, + jlong j_replace_col) { + JNI_NULL_CHECK(env, j_col, "column is null", 0); + JNI_NULL_CHECK(env, j_replace_col, "replacement column is null", 0); + try { + cudf::jni::auto_set_device(env); + auto col = reinterpret_cast(j_col); + auto replacements = reinterpret_cast(j_replace_col); + std::unique_ptr result = cudf::replace_nulls(*col, *replacements); + return reinterpret_cast(result.release()); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_ifElseVV(JNIEnv *env, jclass, jlong j_pred_vec, jlong j_true_vec, diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 5a9404f5760..fe1cba5ceb1 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -1368,7 +1368,7 @@ void testFromScalarNullByte() { } @Test - void testReplaceEmptyColumn() { + void testReplaceNullsScalarEmptyColumn() { try (ColumnVector input = ColumnVector.fromBoxedBooleans(); ColumnVector expected = ColumnVector.fromBoxedBooleans(); Scalar s = Scalar.fromBool(false); @@ -1378,7 +1378,7 @@ void testReplaceEmptyColumn() { } @Test - void testReplaceNullBoolsWithAllNulls() { + void testReplaceNullsScalarBoolsWithAllNulls() { try (ColumnVector input = ColumnVector.fromBoxedBooleans(null, null, null, null); ColumnVector expected = ColumnVector.fromBoxedBooleans(false, false, false, false); Scalar s = Scalar.fromBool(false); @@ -1388,7 +1388,7 @@ void testReplaceNullBoolsWithAllNulls() { } @Test - void testReplaceSomeNullBools() { + void testReplaceNullsScalarSomeNullBools() { try (ColumnVector input = ColumnVector.fromBoxedBooleans(false, null, null, false); ColumnVector expected = ColumnVector.fromBoxedBooleans(false, true, true, false); Scalar s = Scalar.fromBool(true); @@ -1398,7 +1398,7 @@ void testReplaceSomeNullBools() { } @Test - void testReplaceNullIntegersWithAllNulls() { + void testReplaceNullsScalarIntegersWithAllNulls() { try (ColumnVector input = ColumnVector.fromBoxedInts(null, null, null, null); ColumnVector expected = ColumnVector.fromBoxedInts(0, 0, 0, 0); Scalar s = Scalar.fromInt(0); @@ -1408,7 +1408,7 @@ void testReplaceNullIntegersWithAllNulls() { } @Test - void testReplaceSomeNullIntegers() { + void testReplaceNullsScalarSomeNullIntegers() { try (ColumnVector input = ColumnVector.fromBoxedInts(1, 2, null, 4, null); ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, 999, 4, 999); Scalar s = Scalar.fromInt(999); @@ -1418,7 +1418,7 @@ void testReplaceSomeNullIntegers() { } @Test - void testReplaceNullsFailsOnTypeMismatch() { + void testReplaceNullsScalarFailsOnTypeMismatch() { try (ColumnVector input = ColumnVector.fromBoxedInts(1, 2, null, 4, null); Scalar s = Scalar.fromBool(true)) { assertThrows(CudfException.class, () -> input.replaceNulls(s).close()); @@ -1434,6 +1434,44 @@ void testReplaceNullsWithNullScalar() { } } + @Test + void testReplaceNullsColumnEmptyColumn() { + try (ColumnVector input = ColumnVector.fromBoxedBooleans(); + ColumnVector r = ColumnVector.fromBoxedBooleans(); + ColumnVector expected = ColumnVector.fromBoxedBooleans(); + ColumnVector result = input.replaceNulls(r)) { + assertColumnsAreEqual(expected, result); + } + } + + @Test + void testReplaceNullsColumnBools() { + try (ColumnVector input = ColumnVector.fromBoxedBooleans(null, true, null, false); + ColumnVector r = ColumnVector.fromBoxedBooleans(false, null, true, true); + ColumnVector expected = ColumnVector.fromBoxedBooleans(false, true, true, false); + ColumnVector result = input.replaceNulls(r)) { + assertColumnsAreEqual(expected, result); + } + } + + @Test + void testReplaceNullsColumnIntegers() { + try (ColumnVector input = ColumnVector.fromBoxedInts(1, 2, null, 4, null); + ColumnVector r = ColumnVector.fromBoxedInts(996, 997, 998, 909, null); + ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, 998, 4, null); + ColumnVector result = input.replaceNulls(r)) { + assertColumnsAreEqual(expected, result); + } + } + + @Test + void testReplaceNullsColumnFailsOnTypeMismatch() { + try (ColumnVector input = ColumnVector.fromBoxedInts(1, 2, null, 4, null); + ColumnVector r = ColumnVector.fromBoxedBooleans(true)) { + assertThrows(CudfException.class, () -> input.replaceNulls(r).close()); + } + } + static QuantileMethod[] methods = {LINEAR, LOWER, HIGHER, MIDPOINT, NEAREST}; static double[] quantiles = {0.0, 0.25, 0.33, 0.5, 1.0};