Skip to content

Commit

Permalink
provide Java Api for concat_list_by_rows
Browse files Browse the repository at this point in the history
Signed-off-by: sperlingxx <[email protected]>
  • Loading branch information
sperlingxx committed May 6, 2021
1 parent 4715c83 commit accd7f4
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 1 deletion.
41 changes: 41 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,47 @@ public static ColumnVector stringConcatenate(Scalar separator, Scalar narep, Col
return new ColumnVector(stringConcatenation(column_views, separator.getScalarHandle(), narep.getScalarHandle()));
}

/**
* Concatenate columns of lists horizontally (row by row), combining a corresponding row
* from each column into a single list row of a new column. Null list element of input columns
* will be ignored (skipped) during the concatenation.
*
* @param columns array of columns containing lists, must be more than 2 columns
* @return A new java column vector containing the concatenated lists.
*/
public static ColumnVector listConcatenateByRow(ColumnView[] columns) {
return listConcatenateByRow(columns, false);
}

/**
* Concatenate columns of lists horizontally (row by row), combining a corresponding row
* from each column into a single list row of a new column.
*
* @param columns array of columns containing lists, must be more than 2 columns
* @param ignoreNull whether to ignore null list element of input columns: If true, null list
* will be ignored from concatenation; Otherwise, any concatenation involving
* a null list element will result in a null list
* @return A new java column vector containing the concatenated lists.
*/
public static ColumnVector listConcatenateByRow(ColumnView[] columns, boolean ignoreNull) {
assert columns.length >= 2 : "listConcatenateByRow requires at least 2 columns";
long size = columns[0].getRowCount();
long[] columnViews = new long[columns.length];
DType childType = columns[0].getChildColumnView(0).getType();
assert !childType.isNestedType() : "listConcatenateByRow only supports lists with depth 1";

for(int i = 0; i < columns.length; i++) {
assert columns[i] != null : "Column vectors passed may not be null";
assert columns[i].getType().equals(DType.LIST) : "All columns must be of type list";
assert columns[i].getRowCount() == size : "Row count mismatch";
assert childType.equals(columns[i].getChildColumnView(0).getType()) : "Element type mismatch";

columnViews[i] = columns[i].getNativeView();
}

return new ColumnVector(listConcatenationByRow(columnViews, ignoreNull));
}

/**
* Create a new vector containing the MD5 hash of each row in the table.
*
Expand Down
12 changes: 11 additions & 1 deletion java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -2837,7 +2837,7 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat

/**
* Native method to concatenate columns of strings together, combining a row from
* each colunm into a single string.
* each column into a single string.
* @param columnViews array of longs holding the native handles of the column_views to combine.
* @param separator string scalar inserted between each string being merged, may not be null.
* @param narep string scalar indicating null behavior. If set to null and any string in the row is null
Expand All @@ -2849,6 +2849,16 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat
*/
protected static native long stringConcatenation(long[] columnViews, long separator, long narep);


/**
* Native method to concatenate columns of lists together, combining a row from
* each column into a single list.
* @param columnViews array of longs holding the native handles of the column_views to combine.
* @return native handle of the resulting cudf column, used to construct the Java column
* by the listConcatenateByRow method.
*/
protected static native long listConcatenationByRow(long[] columnViews, boolean ignoreNull);

/**
* Native method for map lookup over a column of List<Struct<String,String>>
* @param columnView the column view handle of the map
Expand Down
23 changes: 23 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <cudf/datetime.hpp>
#include <cudf/filling.hpp>
#include <cudf/hashing.hpp>
#include <cudf/lists/concatenate_rows.hpp>
#include <cudf/lists/count_elements.hpp>
#include <cudf/lists/detail/concatenate.hpp>
#include <cudf/lists/extract.hpp>
Expand Down Expand Up @@ -1049,6 +1050,28 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringConcatenation(
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_listConcatenationByRow(JNIEnv *env,
jobject j_object,
jlongArray column_handles,
jboolean ignore_null) {
JNI_NULL_CHECK(env, column_handles, "array of column handles is null", 0);
try {
cudf::jni::auto_set_device(env);
auto null_policy = ignore_null ? cudf::lists::concatenate_null_policy::IGNORE
: cudf::lists::concatenate_null_policy::NULLIFY_OUTPUT_ROW;
cudf::jni::native_jpointerArray<cudf::column_view> n_cudf_columns(env, column_handles);
std::vector<cudf::column_view> column_views;
std::transform(n_cudf_columns.data(), n_cudf_columns.data() + n_cudf_columns.size(),
std::back_inserter(column_views),
[](auto const &p_column) { return *p_column; });

std::unique_ptr<cudf::column> result =
cudf::lists::concatenate_rows(cudf::table_view(column_views), null_policy);
return reinterpret_cast<jlong>(result.release());
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_binaryOpVV(JNIEnv *env, jclass,
jlong lhs_view, jlong rhs_view,
jint int_op, jint out_dtype,
Expand Down
133 changes: 133 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2058,6 +2058,139 @@ void testStringConcatSeparators() {
}
}

@Test
void testListConcatByRow() {
try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32)),
Arrays.asList(0), Arrays.asList(1, 2, 3), null, Arrays.asList(), Arrays.asList());
ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32)),
Arrays.asList(1, 2, 3), Arrays.asList((Integer) null), Arrays.asList(10, 12), Arrays.asList(100, 200, 300),
Arrays.asList());
ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2});
ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32)),
Arrays.asList(0, 1, 2, 3), Arrays.asList(1, 2, 3, null), null, Arrays.asList(100, 200, 300),
Arrays.asList())) {
assertColumnsAreEqual(expect, result);
}

try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.STRING)),
Arrays.asList("AAA", "BBB"), Arrays.asList("aaa"), Arrays.asList("111"), Arrays.asList("X"),
Arrays.asList());
ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.STRING)),
Arrays.asList(), Arrays.asList("bbb", "ccc"), null, Arrays.asList("Y", null),
Arrays.asList());
ColumnVector cv3 = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.STRING)),
Arrays.asList("CCC"), Arrays.asList(), Arrays.asList("222", "333"), Arrays.asList("Z"),
Arrays.asList());
ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2, cv3});
ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.STRING)),
Arrays.asList("AAA", "BBB", "CCC"), Arrays.asList("aaa", "bbb", "ccc"), null,
Arrays.asList("X", "Y", null, "Z"), Arrays.asList())) {
assertColumnsAreEqual(expect, result);
}

try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.FLOAT64)),
Arrays.asList(1.23, 0.0, Double.NaN), Arrays.asList(), null, Arrays.asList(-1.23e10, null));
ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv, cv, cv});
ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.FLOAT64)),
Arrays.asList(1.23, 0.0, Double.NaN, 1.23, 0.0, Double.NaN, 1.23, 0.0, Double.NaN),
Arrays.asList(), null, Arrays.asList(-1.23e10, null, -1.23e10, null, -1.23e10, null))) {
assertColumnsAreEqual(expect, result);
}

assertThrows(AssertionError.class, () -> {
try (ColumnVector cv = ColumnVector.fromInts(1, 2, 3);
ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv, cv})) {
}
}, "All columns must be of type list");

assertThrows(AssertionError.class, () -> {
try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2, 3));
ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv})) {
}
}, "listConcatenateByRow requires at least 2 columns");

assertThrows(AssertionError.class, () -> {
try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32))), Arrays.asList(Arrays.asList(1)));
ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv, cv})) {
}
}, "listConcatenateByRow only supports lists with depth 1");

assertThrows(AssertionError.class, () -> {
try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2, 3));
ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2), Arrays.asList(3));
ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2})) {
}
}, "Row count mismatch");

assertThrows(AssertionError.class, () -> {
try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2, 3));
ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT64)), Arrays.asList(1L));
ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2})) {
}
}, "Element type mismatch");
}

@Test
void testListConcatByRowIgnoreNull() {
try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32)),
Arrays.asList((Integer) null), Arrays.asList(1, 2, 3), null, Arrays.asList(), null);
ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32)),
Arrays.asList(1, 2, 3), null, Arrays.asList(10, 12), Arrays.asList(100, 200, 300), null);
ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2}, true);
ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32)),
Arrays.asList(null, 1, 2, 3), Arrays.asList(1, 2, 3), Arrays.asList(10, 12),
Arrays.asList(100, 200, 300), null)) {
assertColumnsAreEqual(expect, result);
}

try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.STRING)),
Arrays.asList("AAA", "BBB"), Arrays.asList("aaa"), Arrays.asList("111"), null, null);
ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.STRING)),
null, Arrays.asList("bbb", "ccc"), null, Arrays.asList("Y", null), null);
ColumnVector cv3 = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.STRING)),
Arrays.asList("CCC"), Arrays.asList(), Arrays.asList("222", "333"), null, null);
ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2, cv3}, true);
ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.STRING)),
Arrays.asList("AAA", "BBB", "CCC"), Arrays.asList("aaa", "bbb", "ccc"),
Arrays.asList("111", "222", "333"), Arrays.asList("Y", null), null)) {
assertColumnsAreEqual(expect, result);
}

try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.FLOAT64)),
Arrays.asList(1.23, 0.0, Double.NaN), Arrays.asList(), null, Arrays.asList(-1.23e10, null));
ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv, cv, cv}, true);
ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.FLOAT64)),
Arrays.asList(1.23, 0.0, Double.NaN, 1.23, 0.0, Double.NaN, 1.23, 0.0, Double.NaN),
Arrays.asList(), null, Arrays.asList(-1.23e10, null, -1.23e10, null, -1.23e10, null))) {
assertColumnsAreEqual(expect, result);
}
}

@Test
void testPrefixSum() {
try (ColumnVector v1 = ColumnVector.fromLongs(1, 2, 3, 5, 8, 10);
Expand Down

0 comments on commit accd7f4

Please sign in to comment.