Skip to content

Commit

Permalink
Add sample JNI API (#9728)
Browse files Browse the repository at this point in the history
Add sample JNI

Signed-off-by: Chong Gao <[email protected]>

Authors:
  - Chong Gao (https://github.com/res-life)

Approvers:
  - Robert (Bobby) Evans (https://github.com/revans2)

URL: #9728
  • Loading branch information
res-life authored Dec 2, 2021
1 parent c10966c commit 582cc6e
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 0 deletions.
30 changes: 30 additions & 0 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,8 @@ private static native ContiguousTable[] contiguousSplitGroups(long inputTable,
boolean[] keysDescending,
boolean[] keysNullSmallest);

private static native long[] sample(long tableHandle, long n, boolean replacement, long seed);

/////////////////////////////////////////////////////////////////////////////
// TABLE CREATION APIs
/////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -2801,6 +2803,34 @@ public static Table fromPackedTable(ByteBuffer metadata, DeviceMemoryBuffer data
return result;
}


/**
* Gather `n` samples from table randomly
* Note: does not preserve the ordering
* Example:
* input: {col1: {1, 2, 3, 4, 5}, col2: {6, 7, 8, 9, 10}}
* n: 3
* replacement: false
*
* output: {col1: {3, 1, 4}, col2: {8, 6, 9}}
*
* replacement: true
*
* output: {col1: {3, 1, 1}, col2: {8, 6, 6}}
*
* throws "logic_error" if `n` > table rows and `replacement` == FALSE.
* throws "logic_error" if `n` < 0.
*
* @param n non-negative number of samples expected from table
* @param replacement Allow or disallow sampling of the same row more than once.
* @param seed Seed value to initiate random number generator.
*
* @return Table containing samples
*/
public Table sample(long n, boolean replacement, long seed) {
return new Table(sample(nativeHandle, n, replacement, seed));
}

/////////////////////////////////////////////////////////////////////////////
// HELPER CLASSES
/////////////////////////////////////////////////////////////////////////////
Expand Down
15 changes: 15 additions & 0 deletions java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <arrow/ipc/api.h>
#include <cudf/aggregation.hpp>
#include <cudf/concatenate.hpp>
#include <cudf/copying.hpp>
#include <cudf/filling.hpp>
#include <cudf/groupby.hpp>
#include <cudf/hashing.hpp>
Expand Down Expand Up @@ -3147,4 +3148,18 @@ JNIEXPORT jobjectArray JNICALL Java_ai_rapids_cudf_Table_contiguousSplitGroups(
CATCH_STD(env, NULL);
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_sample(JNIEnv *env, jclass, jlong j_input,
jlong n, jboolean replacement,
jlong seed) {
JNI_NULL_CHECK(env, j_input, "input table is null", 0);
try {
cudf::jni::auto_set_device(env);
cudf::table_view *input = reinterpret_cast<cudf::table_view *>(j_input);
auto sample_with_replacement =
replacement ? cudf::sample_with_replacement::TRUE : cudf::sample_with_replacement::FALSE;
std::unique_ptr<cudf::table> result = cudf::sample(*input, n, sample_with_replacement, seed);
return cudf::jni::convert_table_for_return(env, result);
}
CATCH_STD(env, 0);
}
} // extern "C"
21 changes: 21 additions & 0 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -7584,4 +7584,25 @@ void testExplodeOuterPosition() {
}
}
}

@Test
void testSample() {
try (Table t = new Table.TestBuilder().column("s1", "s2", "s3", "s4", "s5").build()) {
try (Table ret = t.sample(3, false, 0);
Table expected = new Table.TestBuilder().column("s3", "s4", "s5").build()) {
assertTablesAreEqual(expected, ret);
}

try (Table ret = t.sample(5, false, 0);
Table expected = new Table.TestBuilder().column("s3", "s4", "s5", "s2", "s1").build()) {
assertTablesAreEqual(expected, ret);
}

try (Table ret = t.sample(8, true, 0);
Table expected = new Table.TestBuilder()
.column("s1", "s1", "s4", "s5", "s5", "s1", "s3", "s2").build()) {
assertTablesAreEqual(expected, ret);
}
}
}
}

0 comments on commit 582cc6e

Please sign in to comment.