Skip to content

Commit

Permalink
Support listConcatenateByRows in Java package (#8171)
Browse files Browse the repository at this point in the history
Current PR is to provide Java API for `cudf::lists::concatenate_rows`, which is added in #8049.

Authors:
  - Alfred Xu (https://github.com/sperlingxx)

Approvers:
  - Jason Lowe (https://github.com/jlowe)
  - Robert (Bobby) Evans (https://github.com/revans2)

URL: #8171
  • Loading branch information
sperlingxx authored May 7, 2021
1 parent db21232 commit 5f9dade
Show file tree
Hide file tree
Showing 5 changed files with 281 additions and 80 deletions.
64 changes: 60 additions & 4 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -514,19 +514,50 @@ public static ColumnVector stringConcatenate(Scalar separator, Scalar narep, Col
assert separator.getType().equals(DType.STRING) : "separator scalar must be a string scalar";
assert narep != null : "narep scalar provided may not be null";
assert narep.getType().equals(DType.STRING) : "narep scalar must be a string scalar";
long size = columns[0].getRowCount();
long[] column_views = new long[columns.length];

long[] column_views = new long[columns.length];
for(int i = 0; i < columns.length; i++) {
assert columns[i] != null : "Column vectors passed may not be null";
assert columns[i].getType().equals(DType.STRING) : "All columns must be of type string for .cat() operation";
assert columns[i].getRowCount() == size : "Row count mismatch, all columns must have the same number of rows";
column_views[i] = columns[i].getNativeView();
}

return new ColumnVector(stringConcatenation(column_views, separator.getScalarHandle(), narep.getScalarHandle()));
}

/**
* Concatenate columns of lists horizontally (row by row), combining a corresponding row
* from each column into a single list row of a new column.
* NOTICE: Any concatenation involving a null list element will result in a null list.
*
* @param columns array of columns containing lists, must be more than 2 columns
* @return A new java column vector containing the concatenated lists.
*/
public static ColumnVector listConcatenateByRow(ColumnView... columns) {
return listConcatenateByRow(false, columns);
}

/**
* Concatenate columns of lists horizontally (row by row), combining a corresponding row
* from each column into a single list row of a new column.
*
* @param ignoreNull whether to ignore null list element of input columns: If true, null list
* will be ignored from concatenation; Otherwise, any concatenation involving
* a null list element will result in a null list
* @param columns array of columns containing lists, must be more than 2 columns
* @return A new java column vector containing the concatenated lists.
*/
public static ColumnVector listConcatenateByRow(boolean ignoreNull, ColumnView... columns) {
assert columns != null : "input columns should not be null";
assert columns.length >= 2 : "listConcatenateByRow requires at least 2 columns";

long[] columnViews = new long[columns.length];
for(int i = 0; i < columns.length; i++) {
columnViews[i] = columns[i].getNativeView();
}

return new ColumnVector(concatListByRow(columnViews, ignoreNull));
}

/**
* Create a new vector containing the MD5 hash of each row in the table.
*
Expand Down Expand Up @@ -669,6 +700,31 @@ private static native long makeList(long[] handles, long typeHandle, int scale,

private static native long concatenate(long[] viewHandles) throws CudfException;

/**
* Native method to concatenate columns of lists horizontally (row by row), combining a row
* from each column into a single list.
*
* @param columnViews array of longs holding the native handles of the column_views to combine.
* @return native handle of the resulting cudf column, used to construct the Java column
* by the listConcatenateByRow method.
*/
private static native long concatListByRow(long[] columnViews, boolean ignoreNull);

/**
* Native method to concatenate columns of strings together, combining a row from
* each column into a single string.
*
* @param columnViews array of longs holding the native handles of the column_views to combine.
* @param separator string scalar inserted between each string being merged, may not be null.
* @param narep string scalar indicating null behavior. If set to null and any string in the row is null
* the resulting string will be null. If not null, null values in any column will be
* replaced by the specified string. The underlying value in the string scalar may be null,
* but the object passed in may not.
* @return native handle of the resulting cudf column, used to construct the Java column
* by the stringConcatenate method.
*/
private static native long stringConcatenation(long[] columnViews, long separator, long narep);

/**
* Native method to hash each row of the given table. Hashing function dispatched on the
* native side using the hashId.
Expand Down
14 changes: 0 additions & 14 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -2835,20 +2835,6 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat

private static native long urlEncode(long cudfViewHandle);

/**
* Native method to concatenate columns of strings together, combining a row from
* each colunm into a single string.
* @param columnViews array of longs holding the native handles of the column_views to combine.
* @param separator string scalar inserted between each string being merged, may not be null.
* @param narep string scalar indicating null behavior. If set to null and any string in the row is null
* the resulting string will be null. If not null, null values in any column will be
* replaced by the specified string. The underlying value in the string scalar may be null,
* but the object passed in may not.
* @return native handle of the resulting cudf column, used to construct the Java column
* by the stringConcatenate method.
*/
protected static native long stringConcatenation(long[] columnViews, long separator, long narep);

/**
* Native method for map lookup over a column of List<Struct<String,String>>
* @param columnView the column view handle of the map
Expand Down
52 changes: 52 additions & 0 deletions java/src/main/native/src/ColumnVectorJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
#include <cudf/reshape.hpp>
#include <cudf/utilities/bit.hpp>
#include <cudf/detail/interop.hpp>
#include <cudf/lists/concatenate_rows.hpp>
#include <cudf/lists/detail/concatenate.hpp>
#include <cudf/lists/lists_column_view.hpp>
#include <cudf/scalar/scalar_factories.hpp>
#include <cudf/strings/combine.hpp>
#include <cudf/structs/structs_column_view.hpp>

#include "cudf_jni_apis.hpp"
Expand Down Expand Up @@ -172,6 +174,56 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromArrow(JNIEnv *env,
CATCH_STD(env, 0);
}


JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_stringConcatenation(JNIEnv *env, jclass,
jlongArray column_handles,
jlong separator,
jlong narep) {
JNI_NULL_CHECK(env, column_handles, "array of column handles is null", 0);
JNI_NULL_CHECK(env, separator, "separator string scalar object is null", 0);
JNI_NULL_CHECK(env, narep, "narep string scalar object is null", 0);
try {
cudf::jni::auto_set_device(env);
const auto& separator_scalar = *reinterpret_cast<cudf::string_scalar*>(separator);
const auto& narep_scalar = *reinterpret_cast<cudf::string_scalar*>(narep);

cudf::jni::native_jpointerArray<cudf::column_view> n_cudf_columns(env, column_handles);
std::vector<cudf::column_view> column_views;
std::transform(n_cudf_columns.data(),
n_cudf_columns.data() + n_cudf_columns.size(),
std::back_inserter(column_views),
[](auto const &p_column) { return *p_column; });

std::unique_ptr<cudf::column> result =
cudf::strings::concatenate(cudf::table_view(column_views), separator_scalar, narep_scalar);
return reinterpret_cast<jlong>(result.release());
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_concatListByRow(JNIEnv *env, jclass,
jlongArray column_handles,
jboolean ignore_null) {
JNI_NULL_CHECK(env, column_handles, "array of column handles is null", 0);
try {
cudf::jni::auto_set_device(env);
auto null_policy = ignore_null ? cudf::lists::concatenate_null_policy::IGNORE
: cudf::lists::concatenate_null_policy::NULLIFY_OUTPUT_ROW;

cudf::jni::native_jpointerArray<cudf::column_view> n_cudf_columns(env, column_handles);
std::vector<cudf::column_view> column_views;
std::transform(n_cudf_columns.data(),
n_cudf_columns.data() + n_cudf_columns.size(),
std::back_inserter(column_views),
[](auto const &p_column) { return *p_column; });

std::unique_ptr<cudf::column> result =
cudf::lists::concatenate_rows(cudf::table_view(column_views), null_policy);
return reinterpret_cast<jlong>(result.release());
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeList(JNIEnv *env, jobject j_object,
jlongArray handles,
jlong j_type,
Expand Down
24 changes: 0 additions & 24 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
#include <cudf/strings/attributes.hpp>
#include <cudf/strings/capitalize.hpp>
#include <cudf/strings/case.hpp>
#include <cudf/strings/combine.hpp>
#include <cudf/strings/contains.hpp>
#include <cudf/strings/convert/convert_booleans.hpp>
#include <cudf/strings/convert/convert_datetime.hpp>
Expand Down Expand Up @@ -1026,29 +1025,6 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_containsRe(JNIEnv *env, j
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringConcatenation(
JNIEnv *env, jobject j_object, jlongArray column_handles, jlong separator, jlong narep) {
JNI_NULL_CHECK(env, column_handles, "array of column handles is null", 0);
JNI_NULL_CHECK(env, separator, "separator string scalar object is null", 0);
JNI_NULL_CHECK(env, narep, "narep string scalar object is null", 0);
try {
cudf::jni::auto_set_device(env);
cudf::string_scalar *separator_scalar = reinterpret_cast<cudf::string_scalar *>(separator);
cudf::string_scalar *narep_scalar = reinterpret_cast<cudf::string_scalar *>(narep);
cudf::jni::native_jpointerArray<cudf::column_view> n_cudf_columns(env, column_handles);
std::vector<cudf::column_view> column_views;
std::transform(n_cudf_columns.data(), n_cudf_columns.data() + n_cudf_columns.size(),
std::back_inserter(column_views),
[](auto const &p_column) { return *p_column; });
cudf::table_view *string_columns = new cudf::table_view(column_views);

std::unique_ptr<cudf::column> result =
cudf::strings::concatenate(*string_columns, *separator_scalar, *narep_scalar);
return reinterpret_cast<jlong>(result.release());
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_binaryOpVV(JNIEnv *env, jclass,
jlong lhs_view, jlong rhs_view,
jint int_op, jint out_dtype,
Expand Down
Loading

0 comments on commit 5f9dade

Please sign in to comment.