Skip to content

Commit

Permalink
Add sample JNI
Browse files Browse the repository at this point in the history
Signed-off-by: Chong Gao <[email protected]>
  • Loading branch information
Chong Gao committed Nov 19, 2021
1 parent 079af45 commit 3b37f06
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
15 changes: 15 additions & 0 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,8 @@ private static native ContiguousTable[] contiguousSplitGroups(long inputTable,
boolean[] keysDescending,
boolean[] keysNullSmallest);

private static native long[] sample(long tableHandle, long n, boolean replacement, long seed);

/////////////////////////////////////////////////////////////////////////////
// TABLE CREATION APIs
/////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -2743,6 +2745,19 @@ public static Table fromPackedTable(ByteBuffer metadata, DeviceMemoryBuffer data
return result;
}

/**
* Gather `n` samples from table randomly
* The output is not same with CPU Sample exec, but this is faster.
*
* @param n
* @param replacement Allow or disallow sampling of the same row more than once.
* @param seed Seed value to initiate random number generator.
* @return
*/
public Table sample(long n, boolean replacement, long seed) {
return new Table(sample(nativeHandle, n, replacement, seed));
}

/////////////////////////////////////////////////////////////////////////////
// HELPER CLASSES
/////////////////////////////////////////////////////////////////////////////
Expand Down
14 changes: 14 additions & 0 deletions java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include <cudf/sorting.hpp>
#include <cudf/stream_compaction.hpp>
#include <cudf/types.hpp>
#include <cudf/copying.hpp>
#include <rmm/cuda_stream_view.hpp>

#include "cudf_jni_apis.hpp"
Expand Down Expand Up @@ -3145,4 +3146,17 @@ JNIEXPORT jobjectArray JNICALL Java_ai_rapids_cudf_Table_contiguousSplitGroups(
CATCH_STD(env, NULL);
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_sample(JNIEnv *env, jclass, jlong j_input,
jlong n, jboolean replacement, jlong seed) {
JNI_NULL_CHECK(env, j_input, "input table is null", 0);
try {
cudf::jni::auto_set_device(env);
cudf::table_view *input = reinterpret_cast<cudf::table_view *>(j_input);
auto sample_with_replacement =
replacement ? cudf::sample_with_replacement::TRUE : cudf::sample_with_replacement::FALSE;
std::unique_ptr<cudf::table> result = cudf::sample(*input, n, sample_with_replacement, seed);
return cudf::jni::convert_table_for_return(env, result);
}
CATCH_STD(env, 0);
}
} // extern "C"

0 comments on commit 3b37f06

Please sign in to comment.