Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Java support on explode_outer [skip ci] #7625

Merged
merged 1 commit into from
Mar 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 117 additions & 24 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,10 @@ private static native long[] repeatColumnCount(long tableHandle,

// Native backing for explodePosition(int): called with this table's native handle and the
// index of the column to explode; the returned long[] is passed to the Table constructor.
private static native long[] explodePosition(long tableHandle, int index);

// Native backing for explodeOuter(int); same calling convention as explodePosition above.
// Implemented in TableJni.cpp by Java_ai_rapids_cudf_Table_explodeOuter.
private static native long[] explodeOuter(long tableHandle, int index);

// Native backing for explodeOuterPosition(int); same calling convention as explodePosition above.
// Implemented in TableJni.cpp by Java_ai_rapids_cudf_Table_explodeOuterPosition.
private static native long[] explodeOuterPosition(long tableHandle, int index);

private static native long createCudfTableView(long[] nativeColumnViewHandles);

private static native long[] columnViewsFromPacked(ByteBuffer metadata, long dataAddress);
Expand Down Expand Up @@ -1724,7 +1728,7 @@ public ContiguousTable[] contiguousSplit(int... indices) {
* Example:
* input: [[5,10,15], 100],
* [[20,25], 200],
* [[30], 300],
* [[30], 300]
* index: 0
* output: [5, 100],
* [10, 100],
Expand All @@ -1736,12 +1740,12 @@ public ContiguousTable[] contiguousSplit(int... indices) {
*
* Nulls propagate in different ways depending on what is null.
* <code>
* [[5,null,15], 100],
* [null, 200]
* returns:
* [5, 100],
* [null, 100],
* [15, 100]
* input: [[5,null,15], 100],
* [null, 200]
* index: 0
* output: [5, 100],
* [null, 100],
* [15, 100]
* </code>
* Note that null lists are completely removed from the output
* and nulls inside lists are pulled out and remain.
Expand All @@ -1762,27 +1766,26 @@ public Table explode(int index) {
* in the output. The corresponding rows for other columns in the input are duplicated. A position
* column is added that has the index inside the original list for each row. Example:
* <code>
* [[5,10,15], 100],
* [[20,25], 200],
* [[30], 300],
* returns
* [0, 5, 100],
* [1, 10, 100],
* [2, 15, 100],
* [0, 20, 200],
* [1, 25, 200],
* [0, 30, 300],
* input: [[5,10,15], 100],
* [[20,25], 200],
* [[30], 300]
* index: 0
* output: [0, 5, 100],
* [1, 10, 100],
* [2, 15, 100],
* [0, 20, 200],
* [1, 25, 200],
* [0, 30, 300]
* </code>
*
* Nulls and empty lists propagate in different ways depending on what is null or empty.
* <code>
* [[5,null,15], 100],
* [null, 200],
* [[], 300],
* returns
* [0, 5, 100],
* [1, null, 100],
* [2, 15, 100],
* input: [[5,null,15], 100],
* [null, 200],
* [[], 300]
* index: 0
* output: [0, 5, 100],
* [1, null, 100],
* [2, 15, 100]
* </code>
*
* Note that null lists are not included in the resulting table, but nulls inside
Expand All @@ -1798,6 +1801,96 @@ public Table explodePosition(int index) {
return new Table(explodePosition(nativeHandle, index));
}

/**
 * Explodes a list column's elements, retaining any null entries or empty lists.
 *
 * Any list is exploded, which means the elements of the list in each row are expanded
 * into new rows in the output. The corresponding rows for other columns in the input
 * are duplicated.
 *
 * <code>
 * Example:
 * input:  [[5,10,15], 100],
 *         [[20,25],   200],
 *         [[30],      300]
 * index: 0
 * output: [5,  100],
 *         [10, 100],
 *         [15, 100],
 *         [20, 200],
 *         [25, 200],
 *         [30, 300]
 * </code>
 *
 * Nulls and empty lists propagate as null entries in the result.
 * <code>
 * input:  [[5,null,15], 100],
 *         [null,        200],
 *         [[],          300]
 * index: 0
 * output: [5,    100],
 *         [null, 100],
 *         [15,   100],
 *         [null, 200],
 *         [null, 300]
 * </code>
 * Note that, unlike explode, null lists and empty lists are NOT removed from the output:
 * each one yields a single row whose exploded value is null. Nulls inside lists are
 * pulled out and remain.
 *
 * The argument checks below are Java assertions, so they only run when assertions are
 * enabled.
 *
 * @param index Column index to explode inside the table.
 * @return A new table with explode_col exploded.
 */
public Table explodeOuter(int index) {
  assert 0 <= index && index < columns.length : "Column index is out of range";
  assert columns[index].getType().equals(DType.LIST) : "Column to explode must be of type LIST";
  return new Table(explodeOuter(nativeHandle, index));
}

/**
 * Explodes a list column's elements retaining any null entries or empty lists and includes a
 * position column.
 *
 * Any list is exploded, which means the elements of the list in each row are expanded into new rows
 * in the output. The corresponding rows for other columns in the input are duplicated. A position
 * column is added that has the index inside the original list for each row.
 *
 * <code>
 * Example:
 * input:  [[5,10,15], 100],
 *         [[20,25],   200],
 *         [[30],      300]
 * index: 0
 * output: [0, 5,  100],
 *         [1, 10, 100],
 *         [2, 15, 100],
 *         [0, 20, 200],
 *         [1, 25, 200],
 *         [0, 30, 300]
 * </code>
 *
 * Nulls and empty lists propagate as null entries in the result.
 * <code>
 * input:  [[5,null,15], 100],
 *         [null,        200],
 *         [[],          300]
 * index: 0
 * output: [0, 5,    100],
 *         [1, null, 100],
 *         [2, 15,   100],
 *         [0, null, 200],
 *         [0, null, 300]
 * </code>
 *
 * The argument checks below are Java assertions, so they only run when assertions are
 * enabled.
 *
 * @param index Column index to explode inside the table.
 * @return A new table with exploded value and position. The column order of return table is
 *         [cols before explode_input, explode_position, explode_value, cols after explode_input].
 */
public Table explodeOuterPosition(int index) {
  assert 0 <= index && index < columns.length : "Column index is out of range";
  assert columns[index].getType().equals(DType.LIST) : "Column to explode must be of type LIST";
  return new Table(explodeOuterPosition(nativeHandle, index));
}

/**
* Gathers the rows of this table according to `gatherMap` such that row "i"
* in the resulting table's columns will contain row "gatherMap[i]" from this table.
Expand Down
29 changes: 29 additions & 0 deletions java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <cudf/io/orc.hpp>
#include <cudf/io/parquet.hpp>
#include <cudf/join.hpp>
#include <cudf/lists/explode.hpp>
#include <cudf/merge.hpp>
#include <cudf/partitioning.hpp>
#include <cudf/reshape.hpp>
Expand Down Expand Up @@ -2046,4 +2047,32 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_explodePosition(JNIEnv *e
CATCH_STD(env, 0);
}

// JNI entry point backing ai.rapids.cudf.Table.explodeOuter(long, int).
//
// input_jtable is the address of a native cudf::table_view and column_index selects the
// LIST column to explode. The table produced by cudf::explode_outer is handed to
// convert_table_for_return, whose jlongArray result goes back to Java.
// NOTE(review): JNI_NULL_CHECK / CATCH_STD presumably raise a Java exception and return
// the trailing 0 on failure -- confirm against the macro definitions.
JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_explodeOuter(JNIEnv *env, jclass,
                                                                    jlong input_jtable,
                                                                    jint column_index) {
  // Name this entry point in the message so a failure is not mistaken for plain explode.
  JNI_NULL_CHECK(env, input_jtable, "explodeOuter: input table is null", 0);
  try {
    cudf::jni::auto_set_device(env);
    // The view is only read here, so treat it as const.
    auto const *input_table = reinterpret_cast<cudf::table_view const *>(input_jtable);
    auto const col_index = static_cast<cudf::size_type>(column_index);
    std::unique_ptr<cudf::table> exploded = cudf::explode_outer(*input_table, col_index);
    return cudf::jni::convert_table_for_return(env, exploded);
  }
  CATCH_STD(env, 0);
}

// JNI entry point backing ai.rapids.cudf.Table.explodeOuterPosition(long, int).
//
// input_jtable is the address of a native cudf::table_view and column_index selects the
// LIST column to explode. The table produced by cudf::explode_outer_position is handed to
// convert_table_for_return, whose jlongArray result goes back to Java.
// NOTE(review): JNI_NULL_CHECK / CATCH_STD presumably raise a Java exception and return
// the trailing 0 on failure -- confirm against the macro definitions.
JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_explodeOuterPosition(JNIEnv *env, jclass,
                                                                            jlong input_jtable,
                                                                            jint column_index) {
  // Name this entry point in the message so a failure is not mistaken for plain explode.
  JNI_NULL_CHECK(env, input_jtable, "explodeOuterPosition: input table is null", 0);
  try {
    cudf::jni::auto_set_device(env);
    // The view is only read here, so treat it as const.
    auto const *input_table = reinterpret_cast<cudf::table_view const *>(input_jtable);
    auto const col_index = static_cast<cudf::size_type>(column_index);
    std::unique_ptr<cudf::table> exploded = cudf::explode_outer_position(*input_table, col_index);
    return cudf::jni::convert_table_for_return(env, exploded);
  }
  CATCH_STD(env, 0);
}

} // extern "C"
86 changes: 73 additions & 13 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4585,7 +4585,7 @@ private Table[] buildExplodeTestTableWithPrimitiveTypes(boolean pos, boolean out
}
}

private Table[] buildExplodeTestTableWithNestedTypes(boolean pos) {
private Table[] buildExplodeTestTableWithNestedTypes(boolean pos, boolean outer) {
StructType nestedType = new StructType(true,
new BasicType(false, DType.INT32), new BasicType(false, DType.STRING));
try (Table input = new Table.TestBuilder()
Expand All @@ -4594,23 +4594,42 @@ private Table[] buildExplodeTestTableWithNestedTypes(boolean pos) {
Arrays.asList(struct(4, "k4"), struct(5, "k5")),
Arrays.asList(struct(6, "k6")),
Arrays.asList(new HostColumnVector.StructData((List) null)),
Arrays.asList())
null)
.column("s1", "s2", "s3", "s4", "s5")
.column(1, 3, 5, 7, 9)
.column(12.0, 14.0, 13.0, 11.0, 15.0)
.build()) {
Table.TestBuilder expectedBuilder = new Table.TestBuilder();
if (pos) {
expectedBuilder.column(0, 1, 2, 0, 1, 0, 0);
if (!outer)
expectedBuilder.column(0, 1, 2, 0, 1, 0, 0);
else
expectedBuilder.column(0, 1, 2, 0, 1, 0, 0, 0);
}
try (Table expected = expectedBuilder
.column(nestedType,
List<Object[]> expectedData = new ArrayList<Object[]>(){{
if (!outer) {
this.add(new HostColumnVector.StructData[]{
struct(1, "k1"), struct(2, "k2"), struct(3, "k3"),
struct(4, "k4"), struct(5, "k5"), struct(6, "k6"),
new HostColumnVector.StructData((List) null)});
this.add(new String[]{"s1", "s1", "s1", "s2", "s2", "s3", "s4"});
this.add(new Integer[]{1, 1, 1, 3, 3, 5, 7});
this.add(new Double[]{12.0, 12.0, 12.0, 14.0, 14.0, 13.0, 11.0});
} else {
this.add(new HostColumnVector.StructData[]{
struct(1, "k1"), struct(2, "k2"), struct(3, "k3"),
struct(4, "k4"), struct(5, "k5"), struct(6, "k6"),
new HostColumnVector.StructData((List) null))
.column("s1", "s1", "s1", "s2", "s2", "s3", "s4")
.column(1, 1, 1, 3, 3, 5, 7)
.column(12.0, 12.0, 12.0, 14.0, 14.0, 13.0, 11.0)
new HostColumnVector.StructData((List) null), null});
this.add(new String[]{"s1", "s1", "s1", "s2", "s2", "s3", "s4", "s5"});
this.add(new Integer[]{1, 1, 1, 3, 3, 5, 7, 9});
this.add(new Double[]{12.0, 12.0, 12.0, 14.0, 14.0, 13.0, 11.0, 15.0});
}
}};
try (Table expected = expectedBuilder
.column(nestedType, (HostColumnVector.StructData[]) expectedData.get(0))
.column((String[]) expectedData.get(1))
.column((Integer[]) expectedData.get(2))
.column((Double[]) expectedData.get(3))
.build()) {
return new Table[]{new Table(input.getColumns()), new Table(expected.getColumns())};
}
Expand All @@ -4629,7 +4648,7 @@ void testExplode() {
}

// Child is nested type
Table[] testTables2 = buildExplodeTestTableWithNestedTypes(false);
Table[] testTables2 = buildExplodeTestTableWithNestedTypes(false, false);
try (Table input = testTables2[0];
Table expected = testTables2[1]) {
try (Table exploded = input.explode(0)) {
Expand All @@ -4639,7 +4658,7 @@ void testExplode() {
}

@Test
void testPosExplode() {
void testExplodePosition() {
// Child is primitive type
Table[] testTables = buildExplodeTestTableWithPrimitiveTypes(true, false);
try (Table input = testTables[0];
Expand All @@ -4649,8 +4668,8 @@ void testPosExplode() {
}
}

// Child is primitive type
Table[] testTables2 = buildExplodeTestTableWithNestedTypes(true);
// Child is nested type
Table[] testTables2 = buildExplodeTestTableWithNestedTypes(true, false);
try (Table input = testTables2[0];
Table expected = testTables2[1]) {
try (Table exploded = input.explodePosition(0)) {
Expand All @@ -4659,4 +4678,45 @@ void testPosExplode() {
}
}

/**
 * Checks explodeOuter on column 0 against the prebuilt expectations, first for a list of
 * primitives and then for a list of structs.
 */
@Test
void testExplodeOuter() {
  // Case 1: the list column holds primitive children.
  Table[] primitiveCase = buildExplodeTestTableWithPrimitiveTypes(false, true);
  try (Table source = primitiveCase[0];
       Table want = primitiveCase[1];
       Table got = source.explodeOuter(0)) {
    assertTablesAreEqual(want, got);
  }

  // Case 2: the list column holds nested (struct) children.
  Table[] nestedCase = buildExplodeTestTableWithNestedTypes(false, true);
  try (Table source = nestedCase[0];
       Table want = nestedCase[1];
       Table got = source.explodeOuter(0)) {
    assertTablesAreEqual(want, got);
  }
}

/**
 * Checks explodeOuterPosition on column 0 against the prebuilt expectations, first for a
 * list of primitives and then for a list of structs.
 */
@Test
void testExplodeOuterPosition() {
  // Case 1: the list column holds primitive children.
  Table[] primitiveCase = buildExplodeTestTableWithPrimitiveTypes(true, true);
  try (Table source = primitiveCase[0];
       Table want = primitiveCase[1];
       Table got = source.explodeOuterPosition(0)) {
    assertTablesAreEqual(want, got);
  }

  // Case 2: the list column holds nested (struct) children.
  Table[] nestedCase = buildExplodeTestTableWithNestedTypes(true, true);
  try (Table source = nestedCase[0];
       Table want = nestedCase[1];
       Table got = source.explodeOuterPosition(0)) {
    assertTablesAreEqual(want, got);
  }
}
}