Skip to content
/ cudf Public
forked from rapidsai/cudf

Commit

Permalink
Add java bindings for distinct count (rapidsai#13573)
Browse files Browse the repository at this point in the history
This adds Java bindings for cudf::distinct_count for both the table and column APIs.

Authors:
  - Robert (Bobby) Evans (https://github.com/revans2)

Approvers:
  - Gera Shegalov (https://github.com/gerashegalov)
  - Nghia Truong (https://github.com/ttnghia)

URL: rapidsai#13573
  • Loading branch information
revans2 authored Jun 14, 2023
1 parent c929a84 commit 649cf5e
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 0 deletions.
18 changes: 18 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -3967,6 +3967,22 @@ static long getValidityBufferSize(int numRows) {
return ((actualBytes + 63) >> 6) << 6;
}

/**
 * Count the number of distinct rows in this column.
 * @param nullPolicy controls whether nulls are included in the count or not.
 * @return the number of distinct rows found.
 */
public int distinctCount(NullPolicy nullPolicy) {
  final boolean countNulls = nullPolicy.includeNulls;
  return distinctCount(getNativeView(), countNulls);
}

/**
 * Count the number of distinct rows in this column, with nulls included
 * in the count.
 * @return the number of distinct rows found.
 */
public int distinctCount() {
  // This overload always asks the native layer to include nulls.
  final boolean countNulls = true;
  return distinctCount(getNativeView(), countNulls);
}

/////////////////////////////////////////////////////////////////////////////
// INTERNAL/NATIVE ACCESS
/////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -3999,6 +4015,8 @@ static DeviceMemoryBufferView getOffsetsBuffer(long viewHandle) {
}

// Native Methods
/**
 * Native method that counts the distinct rows of a column. The JNI implementation
 * (ColumnViewJni.cpp) forwards to cudf::distinct_count with
 * cudf::nan_policy::NAN_IS_VALID.
 * @param handle native handle of the column view to inspect; the JNI side rejects a
 *               null (0) handle with the message "column is null".
 * @param nullsIncluded true selects cudf::null_policy::INCLUDE, false selects
 *                      cudf::null_policy::EXCLUDE.
 * @return the distinct count reported by cudf.
 */
private static native int distinctCount(long handle, boolean nullsIncluded);

/**
* Native method to parse and convert a string column vector to unix timestamp. A unix
* timestamp is a long value representing how many units since 1970-01-01 00:00:00.000 in either
Expand Down
18 changes: 18 additions & 0 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,8 @@ private static native ContigSplitGroupByResult contiguousSplitGroups(long inputT

private static native long[] sample(long tableHandle, long n, boolean replacement, long seed);

/**
 * Native method that counts the distinct rows of a table. The JNI implementation
 * (TableJni.cpp) forwards to cudf::distinct_count.
 * @param handle native handle of the table view to inspect; the JNI side rejects a
 *               null (0) handle with the message "input table is null".
 * @param nullsEqual true selects cudf::null_equality::EQUAL, false selects
 *                   cudf::null_equality::UNEQUAL.
 * @return the distinct count reported by cudf.
 */
private static native int distinctCount(long handle, boolean nullsEqual);

/////////////////////////////////////////////////////////////////////////////
// TABLE CREATION APIs
/////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -2160,6 +2162,22 @@ public Table dropDuplicates(int[] keyColumns, DuplicateKeepOption keep, boolean
return new Table(dropDuplicates(nativeHandle, keyColumns, keep.keepValue, nullsEqual));
}

/**
 * Count the number of distinct rows in this table.
 * @param nullsEqual controls whether nulls are considered equal to each other or not.
 * @return the number of distinct rows found.
 */
public int distinctCount(NullEquality nullsEqual) {
  final boolean treatNullsAsEqual = nullsEqual.nullsEqual;
  return distinctCount(nativeHandle, treatNullsAsEqual);
}

/**
 * Count the number of distinct rows in this table, treating nulls as equal
 * to one another.
 * @return the number of distinct rows found.
 */
public int distinctCount() {
  // This overload always asks the native layer to compare nulls as equal.
  final boolean treatNullsAsEqual = true;
  return distinctCount(nativeHandle, treatNullsAsEqual);
}

/**
* Split a table at given boundaries, but the result of each split has memory that is laid out
* in a contiguous range of memory. This allows for us to optimize copying the data in a single
Expand Down
16 changes: 16 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
#include <cudf/round.hpp>
#include <cudf/scalar/scalar_factories.hpp>
#include <cudf/search.hpp>
#include <cudf/stream_compaction.hpp>
#include <cudf/strings/attributes.hpp>
#include <cudf/strings/capitalize.hpp>
#include <cudf/strings/case.hpp>
Expand Down Expand Up @@ -178,6 +179,21 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceNullsPolicy(JNIEnv
CATCH_STD(env, 0);
}

// JNI entry point for ColumnView.distinctCount(long, boolean).
// j_col is the native address of a cudf::column_view; nulls_included selects between
// cudf::null_policy::INCLUDE and cudf::null_policy::EXCLUDE. NaN handling is fixed to
// cudf::nan_policy::NAN_IS_VALID.
JNIEXPORT jint JNICALL Java_ai_rapids_cudf_ColumnView_distinctCount(JNIEnv *env, jclass,
                                                                    jlong j_col,
                                                                    jboolean nulls_included) {
  JNI_NULL_CHECK(env, j_col, "column is null", 0);
  try {
    cudf::jni::auto_set_device(env);
    auto const *input = reinterpret_cast<cudf::column_view const *>(j_col);
    auto const null_handling =
        nulls_included ? cudf::null_policy::INCLUDE : cudf::null_policy::EXCLUDE;

    return cudf::distinct_count(*input, null_handling, cudf::nan_policy::NAN_IS_VALID);
  }
  CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_ifElseVV(JNIEnv *env, jclass,
jlong j_pred_vec, jlong j_true_vec,
jlong j_false_vec) {
Expand Down
14 changes: 14 additions & 0 deletions java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2921,6 +2921,20 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_filter(JNIEnv *env, jclas
CATCH_STD(env, 0);
}

// JNI entry point for Table.distinctCount(long, boolean).
// input_jtable is the native address of a cudf::table_view; nulls_equal selects between
// cudf::null_equality::EQUAL and cudf::null_equality::UNEQUAL.
JNIEXPORT jint JNICALL Java_ai_rapids_cudf_Table_distinctCount(JNIEnv *env, jclass,
                                                               jlong input_jtable,
                                                               jboolean nulls_equal) {
  JNI_NULL_CHECK(env, input_jtable, "input table is null", 0);
  try {
    cudf::jni::auto_set_device(env);
    auto const *tbl = reinterpret_cast<cudf::table_view const *>(input_jtable);
    auto const null_cmp =
        nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL;

    return cudf::distinct_count(*tbl, null_cmp);
  }
  CATCH_STD(env, 0);
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_dropDuplicates(JNIEnv *env, jclass,
jlong input_jtable,
jintArray key_columns,
Expand Down
8 changes: 8 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ void testTransformVector() {
}
}

@Test
void testDistinctCount() {
  try (ColumnVector input = ColumnVector.fromBoxedLongs(5L, 3L, null, null, 5L)) {
    // Default overload includes nulls: 5, 3 and null.
    int countWithNulls = input.distinctCount();
    assertEquals(3, countWithNulls);

    // Excluding nulls leaves only 5 and 3.
    int countWithoutNulls = input.distinctCount(NullPolicy.EXCLUDE);
    assertEquals(2, countWithoutNulls);
  }
}

@Test
void testClampDouble() {
try (ColumnVector cv = ColumnVector.fromDoubles(2.33d, 32.12d, -121.32d, 0.0d, 0.00001d,
Expand Down
10 changes: 10 additions & 0 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,16 @@ void assertTablesHaveSameValues(HashMap<Object, Integer>[] expectedTable, Table
}
}

@Test
void testDistinctCount() {
  try (Table input = new Table.TestBuilder()
      .column(5, 3, null, null, 5)
      .build()) {
    // Default overload treats nulls as equal: 5, 3 and null.
    int countNullsEqual = input.distinctCount();
    assertEquals(3, countNullsEqual);

    // With nulls unequal, each null row counts separately: 5, 3, null, null.
    int countNullsUnequal = input.distinctCount(NullEquality.UNEQUAL);
    assertEquals(4, countNullsUnequal);
  }
}

@Test
void testMergeSimple() {
try (Table table1 = new Table.TestBuilder()
Expand Down

0 comments on commit 649cf5e

Please sign in to comment.