Skip to content

Commit

Permalink
Java support on explode_outer (#7625)
Browse files Browse the repository at this point in the history
This pull request aims to enable `cudf::explode_outer` and `cudf::explode_outer_position` in Java package.

Authors:
  - Alfred Xu (@sperlingxx)

Approvers:
  - Robert (Bobby) Evans (@revans2)

URL: #7625
  • Loading branch information
sperlingxx authored Mar 18, 2021
1 parent 168c489 commit 99001d2
Show file tree
Hide file tree
Showing 3 changed files with 218 additions and 37 deletions.
141 changes: 117 additions & 24 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,10 @@ private static native long[] repeatColumnCount(long tableHandle,

private static native long[] explodePosition(long tableHandle, int index);

private static native long[] explodeOuter(long tableHandle, int index);

private static native long[] explodeOuterPosition(long tableHandle, int index);

private static native long createCudfTableView(long[] nativeColumnViewHandles);

private static native long[] columnViewsFromPacked(ByteBuffer metadata, long dataAddress);
Expand Down Expand Up @@ -1725,7 +1729,7 @@ public ContiguousTable[] contiguousSplit(int... indices) {
* Example:
* input: [[5,10,15], 100],
* [[20,25], 200],
* [[30], 300],
* [[30], 300]
* index: 0
* output: [5, 100],
* [10, 100],
Expand All @@ -1737,12 +1741,12 @@ public ContiguousTable[] contiguousSplit(int... indices) {
*
* Nulls propagate in different ways depending on what is null.
* <code>
* [[5,null,15], 100],
* [null, 200]
* returns:
* [5, 100],
* [null, 100],
* [15, 100]
* input: [[5,null,15], 100],
* [null, 200]
* index: 0
* output: [5, 100],
* [null, 100],
* [15, 100]
* </code>
* Note that null lists are completely removed from the output
* and nulls inside lists are pulled out and remain.
Expand All @@ -1763,27 +1767,26 @@ public Table explode(int index) {
* in the output. The corresponding rows for other columns in the input are duplicated. A position
* column is added that has the index inside the original list for each row. Example:
* <code>
* [[5,10,15], 100],
* [[20,25], 200],
* [[30], 300],
* returns
* [0, 5, 100],
* [1, 10, 100],
* [2, 15, 100],
* [0, 20, 200],
* [1, 25, 200],
* [0, 30, 300],
* input: [[5,10,15], 100],
* [[20,25], 200],
* [[30], 300]
* index: 0
* output: [0, 5, 100],
* [1, 10, 100],
* [2, 15, 100],
* [0, 20, 200],
* [1, 25, 200],
* [0, 30, 300]
* </code>
*
* Nulls and empty lists propagate in different ways depending on what is null or empty.
* <code>
* [[5,null,15], 100],
* [null, 200],
* [[], 300],
* returns
* [0, 5, 100],
* [1, null, 100],
* [2, 15, 100],
* input: [[5,null,15], 100],
* [null, 200]
* index: 0
* output: [5, 100],
* [null, 100],
* [15, 100]
* </code>
*
* Note that null lists are not included in the resulting table, but nulls inside
Expand All @@ -1799,6 +1802,96 @@ public Table explodePosition(int index) {
return new Table(explodePosition(nativeHandle, index));
}

/**
* Explodes a list column's elements.
*
* Any list is exploded, which means the elements of the list in each row are expanded
* into new rows in the output. The corresponding rows for other columns in the input
* are duplicated.
*
* <code>
* Example:
* input: [[5,10,15], 100],
* [[20,25], 200],
* [[30], 300],
* index: 0
* output: [5, 100],
* [10, 100],
* [15, 100],
* [20, 200],
* [25, 200],
* [30, 300]
* </code>
*
* Nulls propagate in different ways depending on what is null.
* <code>
* input: [[5,null,15], 100],
* [null, 200]
* index: 0
* output: [5, 100],
* [null, 100],
* [15, 100],
* [null, 200]
* </code>
* Note that null lists are completely removed from the output
* and nulls inside lists are pulled out and remain.
*
* @param index Column index to explode inside the table.
* @return A new table with explode_col exploded.
*/
public Table explodeOuter(int index) {
assert 0 <= index && index < columns.length : "Column index is out of range";
assert columns[index].getType().equals(DType.LIST) : "Column to explode must be of type LIST";
return new Table(explodeOuter(nativeHandle, index));
}

/**
* Explodes a list column's elements retaining any null entries or empty lists and includes a
* position column.
*
* Any list is exploded, which means the elements of the list in each row are expanded into new rows
* in the output. The corresponding rows for other columns in the input are duplicated. A position
* column is added that has the index inside the original list for each row. Example:
*
* <code>
* Example:
* input: [[5,10,15], 100],
* [[20,25], 200],
* [[30], 300],
* index: 0
* output: [0, 5, 100],
* [1, 10, 100],
* [2, 15, 100],
* [0, 20, 200],
* [1, 25, 200],
* [0, 30, 300]
* </code>
*
* Nulls and empty lists propagate as null entries in the result.
* <code>
* input: [[5,null,15], 100],
* [null, 200],
* [[], 300]
* index: 0
* output: [0, 5, 100],
* [1, null, 100],
* [2, 15, 100],
* [0, null, 200],
* [0, null, 300]
* </code>
*
* returns
*
* @param index Column index to explode inside the table.
* @return A new table with exploded value and position. The column order of return table is
* [cols before explode_input, explode_position, explode_value, cols after explode_input].
*/
public Table explodeOuterPosition(int index) {
assert 0 <= index && index < columns.length : "Column index is out of range";
assert columns[index].getType().equals(DType.LIST) : "Column to explode must be of type LIST";
return new Table(explodeOuterPosition(nativeHandle, index));
}

/**
* Gathers the rows of this table according to `gatherMap` such that row "i"
* in the resulting table's columns will contain row "gatherMap[i]" from this table.
Expand Down
28 changes: 28 additions & 0 deletions java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2052,4 +2052,32 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_explodePosition(JNIEnv *e
CATCH_STD(env, 0);
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_explodeOuter(JNIEnv *env, jclass,
jlong input_jtable,
jint column_index) {
JNI_NULL_CHECK(env, input_jtable, "explode: input table is null", 0);
try {
cudf::jni::auto_set_device(env);
cudf::table_view *input_table = reinterpret_cast<cudf::table_view *>(input_jtable);
cudf::size_type col_index = static_cast<cudf::size_type>(column_index);
std::unique_ptr<cudf::table> exploded = cudf::explode_outer(*input_table, col_index);
return cudf::jni::convert_table_for_return(env, exploded);
}
CATCH_STD(env, 0);
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_explodeOuterPosition(JNIEnv *env, jclass,
jlong input_jtable,
jint column_index) {
JNI_NULL_CHECK(env, input_jtable, "explode: input table is null", 0);
try {
cudf::jni::auto_set_device(env);
cudf::table_view *input_table = reinterpret_cast<cudf::table_view *>(input_jtable);
cudf::size_type col_index = static_cast<cudf::size_type>(column_index);
std::unique_ptr<cudf::table> exploded = cudf::explode_outer_position(*input_table, col_index);
return cudf::jni::convert_table_for_return(env, exploded);
}
CATCH_STD(env, 0);
}

} // extern "C"
86 changes: 73 additions & 13 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4635,7 +4635,7 @@ private Table[] buildExplodeTestTableWithPrimitiveTypes(boolean pos, boolean out
}
}

private Table[] buildExplodeTestTableWithNestedTypes(boolean pos) {
private Table[] buildExplodeTestTableWithNestedTypes(boolean pos, boolean outer) {
StructType nestedType = new StructType(true,
new BasicType(false, DType.INT32), new BasicType(false, DType.STRING));
try (Table input = new Table.TestBuilder()
Expand All @@ -4644,23 +4644,42 @@ private Table[] buildExplodeTestTableWithNestedTypes(boolean pos) {
Arrays.asList(struct(4, "k4"), struct(5, "k5")),
Arrays.asList(struct(6, "k6")),
Arrays.asList(new HostColumnVector.StructData((List) null)),
Arrays.asList())
null)
.column("s1", "s2", "s3", "s4", "s5")
.column(1, 3, 5, 7, 9)
.column(12.0, 14.0, 13.0, 11.0, 15.0)
.build()) {
Table.TestBuilder expectedBuilder = new Table.TestBuilder();
if (pos) {
expectedBuilder.column(0, 1, 2, 0, 1, 0, 0);
if (!outer)
expectedBuilder.column(0, 1, 2, 0, 1, 0, 0);
else
expectedBuilder.column(0, 1, 2, 0, 1, 0, 0, 0);
}
try (Table expected = expectedBuilder
.column(nestedType,
List<Object[]> expectedData = new ArrayList<Object[]>(){{
if (!outer) {
this.add(new HostColumnVector.StructData[]{
struct(1, "k1"), struct(2, "k2"), struct(3, "k3"),
struct(4, "k4"), struct(5, "k5"), struct(6, "k6"),
new HostColumnVector.StructData((List) null)});
this.add(new String[]{"s1", "s1", "s1", "s2", "s2", "s3", "s4"});
this.add(new Integer[]{1, 1, 1, 3, 3, 5, 7});
this.add(new Double[]{12.0, 12.0, 12.0, 14.0, 14.0, 13.0, 11.0});
} else {
this.add(new HostColumnVector.StructData[]{
struct(1, "k1"), struct(2, "k2"), struct(3, "k3"),
struct(4, "k4"), struct(5, "k5"), struct(6, "k6"),
new HostColumnVector.StructData((List) null))
.column("s1", "s1", "s1", "s2", "s2", "s3", "s4")
.column(1, 1, 1, 3, 3, 5, 7)
.column(12.0, 12.0, 12.0, 14.0, 14.0, 13.0, 11.0)
new HostColumnVector.StructData((List) null), null});
this.add(new String[]{"s1", "s1", "s1", "s2", "s2", "s3", "s4", "s5"});
this.add(new Integer[]{1, 1, 1, 3, 3, 5, 7, 9});
this.add(new Double[]{12.0, 12.0, 12.0, 14.0, 14.0, 13.0, 11.0, 15.0});
}
}};
try (Table expected = expectedBuilder
.column(nestedType, (HostColumnVector.StructData[]) expectedData.get(0))
.column((String[]) expectedData.get(1))
.column((Integer[]) expectedData.get(2))
.column((Double[]) expectedData.get(3))
.build()) {
return new Table[]{new Table(input.getColumns()), new Table(expected.getColumns())};
}
Expand All @@ -4679,7 +4698,7 @@ void testExplode() {
}

// Child is nested type
Table[] testTables2 = buildExplodeTestTableWithNestedTypes(false);
Table[] testTables2 = buildExplodeTestTableWithNestedTypes(false, false);
try (Table input = testTables2[0];
Table expected = testTables2[1]) {
try (Table exploded = input.explode(0)) {
Expand All @@ -4689,7 +4708,7 @@ void testExplode() {
}

@Test
void testPosExplode() {
void testExplodePosition() {
// Child is primitive type
Table[] testTables = buildExplodeTestTableWithPrimitiveTypes(true, false);
try (Table input = testTables[0];
Expand All @@ -4699,8 +4718,8 @@ void testPosExplode() {
}
}

// Child is primitive type
Table[] testTables2 = buildExplodeTestTableWithNestedTypes(true);
// Child is nested type
Table[] testTables2 = buildExplodeTestTableWithNestedTypes(true, false);
try (Table input = testTables2[0];
Table expected = testTables2[1]) {
try (Table exploded = input.explodePosition(0)) {
Expand All @@ -4709,4 +4728,45 @@ void testPosExplode() {
}
}

@Test
void testExplodeOuter() {
// Child is primitive type
Table[] testTables = buildExplodeTestTableWithPrimitiveTypes(false, true);
try (Table input = testTables[0];
Table expected = testTables[1]) {
try (Table exploded = input.explodeOuter(0)) {
assertTablesAreEqual(expected, exploded);
}
}

// Child is nested type
Table[] testTables2 = buildExplodeTestTableWithNestedTypes(false, true);
try (Table input = testTables2[0];
Table expected = testTables2[1]) {
try (Table exploded = input.explodeOuter(0)) {
assertTablesAreEqual(expected, exploded);
}
}
}

@Test
void testExplodeOuterPosition() {
// Child is primitive type
Table[] testTables = buildExplodeTestTableWithPrimitiveTypes(true, true);
try (Table input = testTables[0];
Table expected = testTables[1]) {
try (Table exploded = input.explodeOuterPosition(0)) {
assertTablesAreEqual(expected, exploded);
}
}

// Child is nested type
Table[] testTables2 = buildExplodeTestTableWithNestedTypes(true, true);
try (Table input = testTables2[0];
Table expected = testTables2[1]) {
try (Table exploded = input.explodeOuterPosition(0)) {
assertTablesAreEqual(expected, exploded);
}
}
}
}

0 comments on commit 99001d2

Please sign in to comment.