Skip to content

Commit

Permalink
Add in JNI support for count_elements (#7651)
Browse files Browse the repository at this point in the history
This is just a simple JNI wrapper around cudf::lists::count_elements

Authors:
  - Robert (Bobby) Evans (@revans2)

Approvers:
  - Jason Lowe (@jlowe)

URL: #7651
  • Loading branch information
revans2 authored Mar 19, 2021
1 parent bd29a92 commit 8773a40
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 3 deletions.
11 changes: 11 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,15 @@ public final ColumnVector getByteCount() {
return new ColumnVector(byteCount(getNativeView()));
}

/**
* Get the number of elements for each list. Null lists will have a value of null.
* @return the number of elements in each list as an INT32 value.
*/
public final ColumnVector countElements() {
assert DType.LIST.equals(type) : "Only lists are supported";
return new ColumnVector(countElements(getNativeView()));
}

/**
* Returns a Boolean vector with the same number of rows as this instance, that has
* TRUE for any entry that is not null, and FALSE for any null entry (as per the validity mask)
Expand Down Expand Up @@ -2749,6 +2758,8 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat

private static native long binaryOpVV(long lhs, long rhs, int op, int dtype, int scale);

private static native long countElements(long viewHandle);

private static native long byteCount(long viewHandle) throws CudfException;

private static native long extractListElement(long nativeView, int index);
Expand Down
20 changes: 17 additions & 3 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,17 @@
#include <cudf/binaryop.hpp>
#include <cudf/column/column_factories.hpp>
#include <cudf/concatenate.hpp>
#include <cudf/lists/extract.hpp>
#include <cudf/reshape.hpp>
#include <cudf/lists/detail/concatenate.hpp>
#include <cudf/datetime.hpp>
#include <cudf/filling.hpp>
#include <cudf/hashing.hpp>
#include <cudf/lists/count_elements.hpp>
#include <cudf/lists/detail/concatenate.hpp>
#include <cudf/lists/extract.hpp>
#include <cudf/null_mask.hpp>
#include <cudf/quantiles.hpp>
#include <cudf/reduction.hpp>
#include <cudf/replace.hpp>
#include <cudf/reshape.hpp>
#include <cudf/rolling.hpp>
#include <cudf/round.hpp>
#include <cudf/scalar/scalar_factories.hpp>
Expand Down Expand Up @@ -430,6 +431,19 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_split(JNIEnv *env, j
CATCH_STD(env, NULL);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_countElements(JNIEnv *env, jclass clazz,
jlong view_handle) {
JNI_NULL_CHECK(env, view_handle, "input column is null", 0);
try {
cudf::jni::auto_set_device(env);
cudf::column_view *n_column = reinterpret_cast<cudf::column_view *>(view_handle);
std::unique_ptr<cudf::column> result =
cudf::lists::count_elements(cudf::lists_column_view(*n_column));
return reinterpret_cast<jlong>(result.release());
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_charLengths(JNIEnv *env, jclass clazz,
jlong view_handle) {
JNI_NULL_CHECK(env, view_handle, "input column is null", 0);
Expand Down
12 changes: 12 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1666,6 +1666,18 @@ void testAppendStrings() {
}
}

@Test
void testCountElements() {
DataType dt = new ListType(true, new BasicType(true, DType.INT32));
try (ColumnVector cv = ColumnVector.fromLists(dt, Arrays.asList(1),
Arrays.asList(1, 2), null, Arrays.asList(null, null),
Arrays.asList(1, 2, 3), Arrays.asList(1, 2, 3, 4));
ColumnVector lengths = cv.countElements();
ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, null, 2, 3, 4)) {
TableTest.assertColumnsAreEqual(expected, lengths);
}
}

@Test
void testStringLengths() {
try (ColumnVector cv = ColumnVector.fromStrings("1", "12", null, "123", "1234");
Expand Down

0 comments on commit 8773a40

Please sign in to comment.