From 582cc6e466c7d941e1b34893fd56fbd42fe90d68 Mon Sep 17 00:00:00 2001 From: Chong Gao Date: Thu, 2 Dec 2021 21:12:01 +0800 Subject: [PATCH] Add sample JNI API (#9728) Add sample JNI Signed-off-by: Chong Gao Authors: - Chong Gao (https://github.com/res-life) Approvers: - Robert (Bobby) Evans (https://github.com/revans2) URL: https://github.com/rapidsai/cudf/pull/9728 --- java/src/main/java/ai/rapids/cudf/Table.java | 30 +++++++++++++++++++ java/src/main/native/src/TableJni.cpp | 15 ++++++++++ .../test/java/ai/rapids/cudf/TableTest.java | 21 +++++++++++++ 3 files changed, 66 insertions(+) diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index b0791fb440f..b11808ed023 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -678,6 +678,8 @@ private static native ContiguousTable[] contiguousSplitGroups(long inputTable, boolean[] keysDescending, boolean[] keysNullSmallest); + private static native long[] sample(long tableHandle, long n, boolean replacement, long seed); + ///////////////////////////////////////////////////////////////////////////// // TABLE CREATION APIs ///////////////////////////////////////////////////////////////////////////// @@ -2801,6 +2803,34 @@ public static Table fromPackedTable(ByteBuffer metadata, DeviceMemoryBuffer data return result; } + + /** + * Randomly gather `n` rows from this table and return them as a new table. + * Note: the row order of this table is not preserved in the result. + * Example: + * input: {col1: {1, 2, 3, 4, 5}, col2: {6, 7, 8, 9, 10}} + * n: 3 + * replacement: false + * + * output: {col1: {3, 1, 4}, col2: {8, 6, 9}} + * + * replacement: true + * + * output: {col1: {3, 1, 1}, col2: {8, 6, 6}} + * + * Fails with a native "logic_error" if `n` > table rows and `replacement` is false. + * Fails with a native "logic_error" if `n` < 0. NOTE(review): confirm which Java exception these surface as. + * + * @param n non-negative number of rows to sample from this table + * @param replacement true to allow the same row to be sampled more than once, false to disallow it.
+ * @param seed seed value used to initialize the random number generator. + * + * @return a new Table containing the sampled rows + */ + public Table sample(long n, boolean replacement, long seed) { + return new Table(sample(nativeHandle, n, replacement, seed)); + } + ///////////////////////////////////////////////////////////////////////////// // HELPER CLASSES ///////////////////////////////////////////////////////////////////////////// diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index a78d40a58f7..f3377bb002d 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -3147,4 +3148,18 @@ JNIEXPORT jobjectArray JNICALL Java_ai_rapids_cudf_Table_contiguousSplitGroups( CATCH_STD(env, NULL); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_sample(JNIEnv *env, jclass, jlong j_input, + jlong n, jboolean replacement, + jlong seed) { /* Native side of Table.sample: null-checks j_input, reinterprets it as a cudf::table_view pointer, calls cudf::sample and hands the result to convert_table_for_return. */ + JNI_NULL_CHECK(env, j_input, "input table is null", 0); + try { + cudf::jni::auto_set_device(env); + cudf::table_view *input = reinterpret_cast(j_input); + auto sample_with_replacement = /* maps the Java boolean onto the cudf::sample_with_replacement enum */ + replacement ?
cudf::sample_with_replacement::TRUE : cudf::sample_with_replacement::FALSE; + std::unique_ptr result = cudf::sample(*input, n, sample_with_replacement, seed); /* NOTE(review): n and seed are 64-bit jlong values passed straight through; confirm cudf::sample's parameter types cannot silently narrow a large n. */ + return cudf::jni::convert_table_for_return(env, result); + } + CATCH_STD(env, 0); +} } // extern "C" diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index fa221e19387..0b2f56895e9 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -7584,4 +7584,25 @@ void testExplodeOuterPosition() { } } } + + @Test + void testSample() { /* Exercises sample() on a five-row string column: n=3 and n=5 without replacement, n=8 with replacement, all with seed 0. NOTE(review): the expected row orders are hard-coded; presumably they depend on the native random number generator for seed 0 - confirm they are stable. */ + try (Table t = new Table.TestBuilder().column("s1", "s2", "s3", "s4", "s5").build()) { + try (Table ret = t.sample(3, false, 0); + Table expected = new Table.TestBuilder().column("s3", "s4", "s5").build()) { + assertTablesAreEqual(expected, ret); + } + + try (Table ret = t.sample(5, false, 0); + Table expected = new Table.TestBuilder().column("s3", "s4", "s5", "s2", "s1").build()) { + assertTablesAreEqual(expected, ret); + } + + try (Table ret = t.sample(8, true, 0); + Table expected = new Table.TestBuilder() + .column("s1", "s1", "s4", "s5", "s5", "s1", "s3", "s2").build()) { + assertTablesAreEqual(expected, ret); + } + } + } }