Skip to content

Commit

Permalink
Enable segmented_gather in Java package (#10669)
Browse files Browse the repository at this point in the history
This PR enables the cuDF API `segmented_gather` in the Java package. `segmented_gather` is essential for implementing Spark array functions such as `arrays_zip` (NVIDIA/spark-rapids#5229).

Authors:
  - Alfred Xu (https://github.com/sperlingxx)

Approvers:
  - Robert (Bobby) Evans (https://github.com/revans2)

URL: #10669
  • Loading branch information
sperlingxx authored Apr 19, 2022
1 parent 6c79b59 commit 17d49fa
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 0 deletions.
28 changes: 28 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -1546,6 +1546,31 @@ public ColumnVector segmentedReduce(ColumnView offsets, SegmentedReductionAggreg
}
}

/**
 * Gathers elements from inside each list of this list column, one row at a time, using the
 * list of indices found in the same row of {@code gatherMap}.
 * For a source list of size N, valid indices lie in the range [-N, N); negative indices
 * count back from the end of the list.
 * This overload always applies {@link OutOfBoundsPolicy#NULLIFY}, so an index outside that
 * range produces a null element in the result.
 * @param gatherMap list column view holding, for each row, the integral indices of the
 *                  elements to pick from the list in the matching row of this column.
 * @return a new list column containing the gathered elements.
 */
public ColumnVector segmentedGather(ColumnView gatherMap) {
  // Default to the safe policy: out-of-bounds indices become nulls.
  OutOfBoundsPolicy defaultPolicy = OutOfBoundsPolicy.NULLIFY;
  return segmentedGather(gatherMap, defaultPolicy);
}

/**
 * Gathers elements from inside each list of this list column, one row at a time, using the
 * list of indices found in the same row of {@code gatherMap}.
 * @param gatherMap list column view holding, for each row, the integral indices of the
 *                  elements to pick from the list in the matching row of this column.
 * @param policy how indices outside the bounds of a source list are treated:
 *               {@code NULLIFY} replaces them with null, while {@code DONT_CHECK} skips
 *               the check and leads to undefined behaviour for such indices.
 * @return a new list column containing the gathered elements.
 */
public ColumnVector segmentedGather(ColumnView gatherMap, OutOfBoundsPolicy policy) {
  long sourceHandle = getNativeView();
  long gatherMapHandle = gatherMap.getNativeView();
  // The native layer only distinguishes "nullify" from "do not check".
  boolean nullifyOutOfBounds = policy.equals(OutOfBoundsPolicy.NULLIFY);
  long resultHandle = segmentedGather(sourceHandle, gatherMapHandle, nullifyOutOfBounds);
  return new ColumnVector(resultHandle);
}

/**
* Do a reduction on the values in a list. The output type will be the type of the data column
* of this list.
Expand Down Expand Up @@ -3998,6 +4023,9 @@ private static native long scan(long viewHandle, long aggregation,
private static native long segmentedReduce(long dataViewHandle, long offsetsViewHandle,
long aggregation, boolean includeNulls, int dtype, int scale) throws CudfException;

/**
 * Native implementation of segmented gather on a list column.
 * @param sourceColumnHandle native handle of the source list column view.
 * @param gatherMapListHandle native handle of the list column view holding the per-row
 *                            gather indices.
 * @param isNullifyOutBounds true to use the NULLIFY out-of-bounds policy, false to use
 *                           DONT_CHECK.
 * @return native handle of the newly created result column.
 */
private static native long segmentedGather(long sourceColumnHandle, long gatherMapListHandle,
boolean isNullifyOutBounds) throws CudfException;

private static native long isNullNative(long viewHandle);

private static native long isNanNative(long viewHandle);
Expand Down
18 changes: 18 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <cudf/lists/detail/concatenate.hpp>
#include <cudf/lists/drop_list_duplicates.hpp>
#include <cudf/lists/extract.hpp>
#include <cudf/lists/gather.hpp>
#include <cudf/lists/lists_column_view.hpp>
#include <cudf/lists/sorting.hpp>
#include <cudf/null_mask.hpp>
Expand Down Expand Up @@ -288,6 +289,23 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_segmentedReduce(
CATCH_STD(env, 0);
}

// JNI binding for ColumnView.segmentedGather: wraps both column view handles as lists
// column views and forwards them to cudf::lists::segmented_gather. Returns the handle of
// the resulting column, or 0 when a handle is null or an exception is raised.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_segmentedGather(
    JNIEnv *env, jclass, jlong source_column, jlong gather_map_list, jboolean nullify_out_bounds) {
  JNI_NULL_CHECK(env, source_column, "source column view is null", 0);
  JNI_NULL_CHECK(env, gather_map_list, "gather map is null", 0);
  try {
    cudf::jni::auto_set_device(env);

    // The jlong handles carry pointers to column views.
    auto const source_view = reinterpret_cast<cudf::column_view *>(source_column);
    auto const map_view = reinterpret_cast<cudf::column_view *>(gather_map_list);
    cudf::lists_column_view const source_lists(*source_view);
    cudf::lists_column_view const index_lists(*map_view);

    // Translate the Java-side boolean into the cudf out-of-bounds policy.
    cudf::out_of_bounds_policy bounds_policy = cudf::out_of_bounds_policy::DONT_CHECK;
    if (nullify_out_bounds) {
      bounds_policy = cudf::out_of_bounds_policy::NULLIFY;
    }

    auto gathered = cudf::lists::segmented_gather(source_lists, index_lists, bounds_policy);
    return release_as_jlong(gathered);
  }
  CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_scan(JNIEnv *env, jclass, jlong j_col_view,
jlong j_agg, jboolean is_inclusive,
jboolean include_nulls) {
Expand Down
25 changes: 25 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import ai.rapids.cudf.ColumnView.FindOptions;
import ai.rapids.cudf.HostColumnVector.*;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -6259,4 +6260,28 @@ void testCopyWithBooleanColumnAsValidity() {
});
assertTrue(x.getMessage().contains("Exemplar and validity columns must have the same size"));
}

/**
 * Checks segmentedGather with the default policy against a list-of-strings column that
 * mixes regular, null and empty rows, using positive, negative and out-of-bounds indices.
 */
@Test
void testSegmentedGather() {
  // Nullable lists of nullable strings.
  HostColumnVector.DataType stringListType =
      new ListType(true, new BasicType(true, DType.STRING));
  // Non-nullable lists of non-nullable INT32 indices.
  ListType indexListType = new ListType(false, new BasicType(false, DType.INT32));
  try (ColumnVector input = ColumnVector.fromLists(stringListType,
           Lists.newArrayList("a", "b", null, "c"),
           null,
           Lists.newArrayList(),
           Lists.newArrayList(null, "A", "B", "C", "D"));
       ColumnVector indices = ColumnVector.fromLists(indexListType,
           // -3 counts from the end of the 4-element list; 4 is out of bounds
           Lists.newArrayList(-3, 0, 2, 3, 4),
           // nothing to gather from the null row
           Lists.newArrayList(),
           // index 1 is out of bounds for the empty list
           Lists.newArrayList(1),
           // 5 and -6 are out of bounds for the 5-element list
           Lists.newArrayList(1, -4, 5, -1, -6));
       ColumnVector gathered = input.segmentedGather(indices);
       ColumnVector wanted = ColumnVector.fromLists(stringListType,
           Lists.newArrayList("b", "a", null, "c", null),
           null,
           Lists.newArrayList((String) null),
           Lists.newArrayList("A", "A", null, "D", null))) {
    assertColumnsAreEqual(wanted, gathered);
  }
}
}

0 comments on commit 17d49fa

Please sign in to comment.