Skip to content

Commit

Permalink
Added JNI for getMapValueForKeys (#11104)
Browse files Browse the repository at this point in the history
This PR adds Java method for getting values for a list of keys

fixes #10818

Authors:
  - Raza Jafri (https://github.com/razajafri)

Approvers:
  - Robert (Bobby) Evans (https://github.com/revans2)

URL: #11104
  • Loading branch information
razajafri authored Jun 16, 2022
1 parent ac2d9a6 commit ef6a390
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 0 deletions.
28 changes: 28 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -3278,6 +3278,19 @@ private static void assertIsSupportedMapKeyType(DType keyType) {
assert isSupportedKeyType : "Map lookup by STRUCT and LIST keys is not supported.";
}

/**
* Given a column of type List<Struct<X, Y>> and a key column of type X, return a column of type Y,
* where each row in the output column is the Y value corresponding to the X key.
* If the key is not found, the corresponding output value is null.
* @param keys the column view with keys to lookup in the column
* @return a column of values or nulls based on the lookup result
*/
public final ColumnVector getMapValue(ColumnView keys) {
assert type.equals(DType.LIST) : "column type must be a LIST";
assert keys != null : "Lookup key may not be null";
return new ColumnVector(mapLookupForKeys(getNativeView(), keys.getNativeView()));
}

/**
* Given a column of type List<Struct<X, Y>> and a key of type X, return a column of type Y,
* where each row in the output column is the Y value corresponding to the X key.
Expand Down Expand Up @@ -3913,6 +3926,21 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat
*/
private static native long mapLookup(long columnView, long key) throws CudfException;

/**
* Native method for map lookup over a column of List<Struct<String,String>>
* The lookup column must have as many rows as the map column,
* and must match the key-type of the map.
* A column of values is returned, with the same number of rows as the map column.
* If a key is repeated in a map row, the value corresponding to the last matching
* key is returned.
* If a lookup key is null or not found, the corresponding value is null.
* @param columnView the column view handle of the map
* @param keys the column view holding the keys
* @return a column of values corresponding the value of the lookup key.
* @throws CudfException
*/
private static native long mapLookupForKeys(long columnView, long keys) throws CudfException;

/**
* Native method for check the existence of a key over a column of List<Struct<String,String>>
* @param columnView the column view handle of the map
Expand Down
15 changes: 15 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1391,6 +1391,21 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringReplace(JNIEnv *env
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_mapLookupForKeys(JNIEnv *env, jclass,
jlong map_column_view,
jlong lookup_keys) {
JNI_NULL_CHECK(env, map_column_view, "column is null", 0);
JNI_NULL_CHECK(env, lookup_keys, "lookup key is null", 0);
try {
cudf::jni::auto_set_device(env);
auto const *cv = reinterpret_cast<cudf::column_view *>(map_column_view);
auto const *column_keys = reinterpret_cast<cudf::column_view *>(lookup_keys);
auto const maps_view = cudf::jni::maps_column_view{*cv};
return release_as_jlong(maps_view.get_values_for(*column_keys));
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_mapLookup(JNIEnv *env, jclass,
jlong map_column_view,
jlong lookup_key) {
Expand Down
15 changes: 15 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -5864,6 +5864,21 @@ void testStructChildValidity() {
}
}

@Test
void testGetMapValueForKeys() {
List<HostColumnVector.StructData> list1 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList(1, 2)));
List<HostColumnVector.StructData> list2 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList(2, 3)));
List<HostColumnVector.StructData> list3 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList(5, 4)));
HostColumnVector.StructType structType = new HostColumnVector.StructType(true, Arrays.asList(new HostColumnVector.BasicType(true, DType.INT32),
new HostColumnVector.BasicType(true, DType.INT32)));
try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, structType), list1, list2, list3);
ColumnVector lookupKey = ColumnVector.fromInts(1, 6, 5);
ColumnVector res = cv.getMapValue(lookupKey);
ColumnVector expected = ColumnVector.fromBoxedInts(2, null, 4)) {
assertColumnsAreEqual(expected, res);
}
}

@Test
void testGetMapValueForInteger() {
List<HostColumnVector.StructData> list1 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList(1, 2)));
Expand Down

0 comments on commit ef6a390

Please sign in to comment.