Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support listConcatenateByRows in Java package [skip ci] #8171

Merged
merged 4 commits into from
May 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 60 additions & 4 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -514,19 +514,50 @@ public static ColumnVector stringConcatenate(Scalar separator, Scalar narep, Col
assert separator.getType().equals(DType.STRING) : "separator scalar must be a string scalar";
assert narep != null : "narep scalar provided may not be null";
assert narep.getType().equals(DType.STRING) : "narep scalar must be a string scalar";
long size = columns[0].getRowCount();
long[] column_views = new long[columns.length];

long[] column_views = new long[columns.length];
for(int i = 0; i < columns.length; i++) {
assert columns[i] != null : "Column vectors passed may not be null";
assert columns[i].getType().equals(DType.STRING) : "All columns must be of type string for .cat() operation";
assert columns[i].getRowCount() == size : "Row count mismatch, all columns must have the same number of rows";
column_views[i] = columns[i].getNativeView();
}

return new ColumnVector(stringConcatenation(column_views, separator.getScalarHandle(), narep.getScalarHandle()));
}

/**
* Concatenate columns of lists horizontally (row by row), combining a corresponding row
* from each column into a single list row of a new column.
* NOTICE: Any concatenation involving a null list element will result in a null list.
* This overload delegates to {@link #listConcatenateByRow(boolean, ColumnView...)} with
* ignoreNull set to false.
*
* @param columns array of columns containing lists, must contain at least 2 columns
* @return A new java column vector containing the concatenated lists.
*/
public static ColumnVector listConcatenateByRow(ColumnView... columns) {
return listConcatenateByRow(false, columns);
}

/**
* Concatenate columns of lists horizontally (row by row), combining a corresponding row
* from each column into a single list row of a new column.
*
* @param ignoreNull whether to ignore null list elements of the input columns: if true, null
*                   lists are skipped during concatenation; otherwise, any concatenation
*                   involving a null list element results in a null list
* @param columns array of columns containing lists; must contain at least 2 columns and no
*                null entries
* @return A new java column vector containing the concatenated lists.
*/
public static ColumnVector listConcatenateByRow(boolean ignoreNull, ColumnView... columns) {
  assert columns != null : "input columns should not be null";
  assert columns.length >= 2 : "listConcatenateByRow requires at least 2 columns";

  long[] columnViews = new long[columns.length];
  for (int i = 0; i < columns.length; i++) {
    // When assertions are enabled, report a null entry with a clear message instead of
    // letting getNativeView() fail on it (same check stringConcatenate performs).
    assert columns[i] != null : "Column vectors passed may not be null";
    columnViews[i] = columns[i].getNativeView();
  }

  return new ColumnVector(concatListByRow(columnViews, ignoreNull));
}

/**
* Create a new vector containing the MD5 hash of each row in the table.
*
Expand Down Expand Up @@ -669,6 +700,31 @@ private static native long makeList(long[] handles, long typeHandle, int scale,

/**
* Native method taking an array of native view handles.
* NOTE(review): presumably returns the native handle of a newly created, concatenated
* column; the native implementation is not visible here, so confirm before relying on it.
*
* @param viewHandles array of longs holding native view handles.
* @return a long produced by the native side (presumably a native column handle; TODO confirm).
* @throws CudfException declared by the signature; the failure conditions are not visible here.
*/
private static native long concatenate(long[] viewHandles) throws CudfException;

/**
* Native method to concatenate columns of lists horizontally (row by row), combining a row
* from each column into a single list.
*
* @param columnViews array of longs holding the native handles of the column_views to combine.
* @param ignoreNull if true, null lists are ignored during concatenation; if false, any
*                   concatenation involving a null list element results in a null list.
* @return native handle of the resulting cudf column, used to construct the Java column
* by the listConcatenateByRow method.
*/
private static native long concatListByRow(long[] columnViews, boolean ignoreNull);

/**
* Native method to concatenate columns of strings together, combining a row from
* each column into a single string.
*
* @param columnViews array of longs holding the native handles of the column_views to combine.
* @param separator native handle of the string scalar inserted between each string being
*                  merged, may not be null.
* @param narep native handle of the string scalar indicating null behavior. If the scalar's
*              underlying value is null and any string in the row is null, the resulting
*              string will be null. If the underlying value is not null, null values in any
*              column will be replaced by the specified string. The underlying value in the
*              string scalar may be null, but the object passed in may not.
* @return native handle of the resulting cudf column, used to construct the Java column
* by the stringConcatenate method.
*/
private static native long stringConcatenation(long[] columnViews, long separator, long narep);

/**
* Native method to hash each row of the given table. Hashing function dispatched on the
* native side using the hashId.
Expand Down
14 changes: 0 additions & 14 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -2835,20 +2835,6 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat

private static native long urlEncode(long cudfViewHandle);

/**
* Native method to concatenate columns of strings together, combining a row from
* each column into a single string.
* @param columnViews array of longs holding the native handles of the column_views to combine.
* @param separator string scalar inserted between each string being merged, may not be null.
* @param narep string scalar indicating null behavior. If its underlying value is null and any
* string in the row is null the resulting string will be null. If not null, null values in
* any column will be replaced by the specified string. The underlying value in the string
* scalar may be null, but the object passed in may not.
* @return native handle of the resulting cudf column, used to construct the Java column
* by the stringConcatenate method.
*/
protected static native long stringConcatenation(long[] columnViews, long separator, long narep);

/**
* Native method for map lookup over a column of List<Struct<String,String>>
* @param columnView the column view handle of the map
Expand Down
52 changes: 52 additions & 0 deletions java/src/main/native/src/ColumnVectorJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
#include <cudf/reshape.hpp>
#include <cudf/utilities/bit.hpp>
#include <cudf/detail/interop.hpp>
#include <cudf/lists/concatenate_rows.hpp>
#include <cudf/lists/detail/concatenate.hpp>
#include <cudf/lists/lists_column_view.hpp>
#include <cudf/scalar/scalar_factories.hpp>
#include <cudf/strings/combine.hpp>
#include <cudf/structs/structs_column_view.hpp>

#include "cudf_jni_apis.hpp"
Expand Down Expand Up @@ -172,6 +174,56 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromArrow(JNIEnv *env,
CATCH_STD(env, 0);
}


/**
 * JNI entry point backing ColumnVector.stringConcatenation.
 *
 * Interprets `column_handles` as an array of cudf::column_view pointers and `separator` /
 * `narep` as cudf::string_scalar pointers, calls cudf::strings::concatenate on the columns
 * and returns the released result column pointer as a jlong.
 */
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_stringConcatenation(JNIEnv *env, jclass,
                                                                             jlongArray column_handles,
                                                                             jlong separator,
                                                                             jlong narep) {
  JNI_NULL_CHECK(env, column_handles, "array of column handles is null", 0);
  JNI_NULL_CHECK(env, separator, "separator string scalar object is null", 0);
  JNI_NULL_CHECK(env, narep, "narep string scalar object is null", 0);
  try {
    cudf::jni::auto_set_device(env);

    // Copy each column_view referenced by the Java-side handles into a local vector.
    cudf::jni::native_jpointerArray<cudf::column_view> handles(env, column_handles);
    std::vector<cudf::column_view> views;
    auto const first = handles.data();
    auto const last = first + handles.size();
    for (auto it = first; it != last; ++it) {
      views.push_back(**it);
    }

    // The scalar handles are raw pointers passed through Java as longs.
    auto const *sep = reinterpret_cast<cudf::string_scalar *>(separator);
    auto const *null_replacement = reinterpret_cast<cudf::string_scalar *>(narep);

    std::unique_ptr<cudf::column> joined =
        cudf::strings::concatenate(cudf::table_view(views), *sep, *null_replacement);
    // Ownership of the result column is handed to the caller as a raw handle.
    return reinterpret_cast<jlong>(joined.release());
  }
  CATCH_STD(env, 0);
}

/**
 * JNI entry point backing ColumnVector.concatListByRow.
 *
 * Interprets `column_handles` as an array of cudf::column_view pointers, calls
 * cudf::lists::concatenate_rows on them and returns the released result column pointer as
 * a jlong. `ignore_null` selects concatenate_null_policy::IGNORE when true and
 * concatenate_null_policy::NULLIFY_OUTPUT_ROW when false.
 */
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_concatListByRow(JNIEnv *env, jclass,
                                                                         jlongArray column_handles,
                                                                         jboolean ignore_null) {
  JNI_NULL_CHECK(env, column_handles, "array of column handles is null", 0);
  try {
    cudf::jni::auto_set_device(env);

    // Map the Java boolean onto the cudf null-handling policy.
    cudf::lists::concatenate_null_policy policy =
        cudf::lists::concatenate_null_policy::NULLIFY_OUTPUT_ROW;
    if (ignore_null) {
      policy = cudf::lists::concatenate_null_policy::IGNORE;
    }

    // Copy each column_view referenced by the Java-side handles into a local vector.
    cudf::jni::native_jpointerArray<cudf::column_view> handles(env, column_handles);
    std::vector<cudf::column_view> views;
    auto const first = handles.data();
    auto const last = first + handles.size();
    for (auto it = first; it != last; ++it) {
      views.push_back(**it);
    }

    std::unique_ptr<cudf::column> combined =
        cudf::lists::concatenate_rows(cudf::table_view(views), policy);
    // Ownership of the result column is handed to the caller as a raw handle.
    return reinterpret_cast<jlong>(combined.release());
  }
  CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeList(JNIEnv *env, jobject j_object,
jlongArray handles,
jlong j_type,
Expand Down
24 changes: 0 additions & 24 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
#include <cudf/strings/attributes.hpp>
#include <cudf/strings/capitalize.hpp>
#include <cudf/strings/case.hpp>
#include <cudf/strings/combine.hpp>
#include <cudf/strings/contains.hpp>
#include <cudf/strings/convert/convert_booleans.hpp>
#include <cudf/strings/convert/convert_datetime.hpp>
Expand Down Expand Up @@ -1026,29 +1025,6 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_containsRe(JNIEnv *env, j
CATCH_STD(env, 0);
}

/**
 * JNI entry point backing ColumnView.stringConcatenation.
 *
 * Interprets `column_handles` as an array of cudf::column_view pointers and `separator` /
 * `narep` as cudf::string_scalar pointers, calls cudf::strings::concatenate on the columns
 * and returns the released result column pointer as a jlong.
 */
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringConcatenation(
    JNIEnv *env, jobject j_object, jlongArray column_handles, jlong separator, jlong narep) {
  JNI_NULL_CHECK(env, column_handles, "array of column handles is null", 0);
  JNI_NULL_CHECK(env, separator, "separator string scalar object is null", 0);
  JNI_NULL_CHECK(env, narep, "narep string scalar object is null", 0);
  try {
    cudf::jni::auto_set_device(env);
    cudf::string_scalar *separator_scalar = reinterpret_cast<cudf::string_scalar *>(separator);
    cudf::string_scalar *narep_scalar = reinterpret_cast<cudf::string_scalar *>(narep);
    cudf::jni::native_jpointerArray<cudf::column_view> n_cudf_columns(env, column_handles);
    std::vector<cudf::column_view> column_views;
    std::transform(n_cudf_columns.data(), n_cudf_columns.data() + n_cudf_columns.size(),
                   std::back_inserter(column_views),
                   [](auto const &p_column) { return *p_column; });

    // The table_view is only needed for the duration of the concatenate call, so it is built
    // as a temporary. It was previously allocated with `new` and never deleted, which leaked
    // one table_view per invocation.
    std::unique_ptr<cudf::column> result = cudf::strings::concatenate(
        cudf::table_view(column_views), *separator_scalar, *narep_scalar);
    return reinterpret_cast<jlong>(result.release());
  }
  CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_binaryOpVV(JNIEnv *env, jclass,
jlong lhs_view, jlong rhs_view,
jint int_op, jint out_dtype,
Expand Down
Loading