diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 7385b55d0df..d0e59fdc105 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -516,6 +516,10 @@ private static native long[] repeatColumnCount(long tableHandle, private static native long[] explodePosition(long tableHandle, int index); + private static native long[] explodeOuter(long tableHandle, int index); + + private static native long[] explodeOuterPosition(long tableHandle, int index); + private static native long createCudfTableView(long[] nativeColumnViewHandles); private static native long[] columnViewsFromPacked(ByteBuffer metadata, long dataAddress); @@ -1725,7 +1729,7 @@ public ContiguousTable[] contiguousSplit(int... indices) { * Example: * input: [[5,10,15], 100], * [[20,25], 200], - * [[30], 300], + * [[30], 300] * index: 0 * output: [5, 100], * [10, 100], @@ -1737,12 +1741,12 @@ public ContiguousTable[] contiguousSplit(int... indices) { * * Nulls propagate in different ways depending on what is null. * - * [[5,null,15], 100], - * [null, 200] - * returns: - * [5, 100], - * [null, 100], - * [15, 100] + * input: [[5,null,15], 100], + * [null, 200] + * index: 0 + * output: [5, 100], + * [null, 100], + * [15, 100] * * Note that null lists are completely removed from the output * and nulls inside lists are pulled out and remain. @@ -1763,27 +1767,26 @@ public Table explode(int index) { * in the output. The corresponding rows for other columns in the input are duplicated. A position * column is added that has the index inside the original list for each row. Example: * - * [[5,10,15], 100], - * [[20,25], 200], - * [[30], 300], - * returns - * [0, 5, 100], - * [1, 10, 100], - * [2, 15, 100], - * [0, 20, 200], - * [1, 25, 200], - * [0, 30, 300], + * input: [[5,10,15], 100], + * [[20,25], 200], + * [[30], 300] + * index: 0 + * output: [0, 5, 100], + * [1, 10, 100], + * [2, 15, 100], + * [0, 20, 200], + * [1, 25, 200], + * [0, 30, 300] * * * Nulls and empty lists propagate in different ways depending on what is null or empty. * - * [[5,null,15], 100], - * [null, 200], - * [[], 300], - * returns - * [0, 5, 100], - * [1, null, 100], - * [2, 15, 100], + * input: [[5,null,15], 100], + * [null, 200] + * index: 0 + * output: [5, 100], + * [null, 100], + * [15, 100] * * * Note that null lists are not included in the resulting table, but nulls inside @@ -1799,6 +1802,96 @@ public Table explodePosition(int index) { return new Table(explodePosition(nativeHandle, index)); } + /** + * Explodes a list column's elements. + * + * Any list is exploded, which means the elements of the list in each row are expanded + * into new rows in the output. The corresponding rows for other columns in the input + * are duplicated. + * + * + * Example: + * input: [[5,10,15], 100], + * [[20,25], 200], + * [[30], 300], + * index: 0 + * output: [5, 100], + * [10, 100], + * [15, 100], + * [20, 200], + * [25, 200], + * [30, 300] + * + * + * Nulls propagate in different ways depending on what is null. + * + * input: [[5,null,15], 100], + * [null, 200] + * index: 0 + * output: [5, 100], + * [null, 100], + * [15, 100], + * [null, 200] + * + * Note that null lists are completely removed from the output + * and nulls inside lists are pulled out and remain. + * + * @param index Column index to explode inside the table. + * @return A new table with explode_col exploded. + */ + public Table explodeOuter(int index) { + assert 0 <= index && index < columns.length : "Column index is out of range"; + assert columns[index].getType().equals(DType.LIST) : "Column to explode must be of type LIST"; + return new Table(explodeOuter(nativeHandle, index)); + } + + /** + * Explodes a list column's elements retaining any null entries or empty lists and includes a + * position column. + * + * Any list is exploded, which means the elements of the list in each row are expanded into new rows + * in the output. The corresponding rows for other columns in the input are duplicated. A position + * column is added that has the index inside the original list for each row. Example: + * + * + * Example: + * input: [[5,10,15], 100], + * [[20,25], 200], + * [[30], 300], + * index: 0 + * output: [0, 5, 100], + * [1, 10, 100], + * [2, 15, 100], + * [0, 20, 200], + * [1, 25, 200], + * [0, 30, 300] + * + * + * Nulls and empty lists propagate as null entries in the result. + * + * input: [[5,null,15], 100], + * [null, 200], + * [[], 300] + * index: 0 + * output: [0, 5, 100], + * [1, null, 100], + * [2, 15, 100], + * [0, null, 200], + * [0, null, 300] + * + * + * returns + * + * @param index Column index to explode inside the table. + * @return A new table with exploded value and position. The column order of return table is + * [cols before explode_input, explode_position, explode_value, cols after explode_input]. + */ + public Table explodeOuterPosition(int index) { + assert 0 <= index && index < columns.length : "Column index is out of range"; + assert columns[index].getType().equals(DType.LIST) : "Column to explode must be of type LIST"; + return new Table(explodeOuterPosition(nativeHandle, index)); + } + /** * Gathers the rows of this table according to `gatherMap` such that row "i" * in the resulting table's columns will contain row "gatherMap[i]" from this table. diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index 4548156055a..02385a453d0 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -2052,4 +2052,32 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_explodePosition(JNIEnv *e CATCH_STD(env, 0); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_explodeOuter(JNIEnv *env, jclass, + jlong input_jtable, + jint column_index) { + JNI_NULL_CHECK(env, input_jtable, "explode: input table is null", 0); + try { + cudf::jni::auto_set_device(env); + cudf::table_view *input_table = reinterpret_cast(input_jtable); + cudf::size_type col_index = static_cast(column_index); + std::unique_ptr exploded = cudf::explode_outer(*input_table, col_index); + return cudf::jni::convert_table_for_return(env, exploded); + } + CATCH_STD(env, 0); +} + +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_explodeOuterPosition(JNIEnv *env, jclass, + jlong input_jtable, + jint column_index) { + JNI_NULL_CHECK(env, input_jtable, "explode: input table is null", 0); + try { + cudf::jni::auto_set_device(env); + cudf::table_view *input_table = reinterpret_cast(input_jtable); + cudf::size_type col_index = static_cast(column_index); + std::unique_ptr exploded = cudf::explode_outer_position(*input_table, col_index); + return cudf::jni::convert_table_for_return(env, exploded); + } + CATCH_STD(env, 0); +} + } // extern "C" diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 626f7828012..c2e28e1cad8 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -4635,7 +4635,7 @@ private Table[] buildExplodeTestTableWithPrimitiveTypes(boolean pos, boolean out } } - private Table[] buildExplodeTestTableWithNestedTypes(boolean pos) { + private Table[] buildExplodeTestTableWithNestedTypes(boolean pos, boolean outer) { StructType nestedType = new StructType(true, new BasicType(false, DType.INT32), new BasicType(false, DType.STRING)); try (Table input = new Table.TestBuilder() @@ -4644,23 +4644,42 @@ private Table[] buildExplodeTestTableWithNestedTypes(boolean pos) { Arrays.asList(struct(4, "k4"), struct(5, "k5")), Arrays.asList(struct(6, "k6")), Arrays.asList(new HostColumnVector.StructData((List) null)), - Arrays.asList()) + null) .column("s1", "s2", "s3", "s4", "s5") .column(1, 3, 5, 7, 9) .column(12.0, 14.0, 13.0, 11.0, 15.0) .build()) { Table.TestBuilder expectedBuilder = new Table.TestBuilder(); if (pos) { - expectedBuilder.column(0, 1, 2, 0, 1, 0, 0); + if (!outer) + expectedBuilder.column(0, 1, 2, 0, 1, 0, 0); + else + expectedBuilder.column(0, 1, 2, 0, 1, 0, 0, 0); } - try (Table expected = expectedBuilder - .column(nestedType, + List expectedData = new ArrayList(){{ + if (!outer) { + this.add(new HostColumnVector.StructData[]{ + struct(1, "k1"), struct(2, "k2"), struct(3, "k3"), + struct(4, "k4"), struct(5, "k5"), struct(6, "k6"), + new HostColumnVector.StructData((List) null)}); + this.add(new String[]{"s1", "s1", "s1", "s2", "s2", "s3", "s4"}); + this.add(new Integer[]{1, 1, 1, 3, 3, 5, 7}); + this.add(new Double[]{12.0, 12.0, 12.0, 14.0, 14.0, 13.0, 11.0}); + } else { + this.add(new HostColumnVector.StructData[]{ struct(1, "k1"), struct(2, "k2"), struct(3, "k3"), struct(4, "k4"), struct(5, "k5"), struct(6, "k6"), - new HostColumnVector.StructData((List) null)) - .column("s1", "s1", "s1", "s2", "s2", "s3", "s4") - .column(1, 1, 1, 3, 3, 5, 7) - .column(12.0, 12.0, 12.0, 14.0, 14.0, 13.0, 11.0) + new HostColumnVector.StructData((List) null), null}); + this.add(new String[]{"s1", "s1", "s1", "s2", "s2", "s3", "s4", "s5"}); + this.add(new Integer[]{1, 1, 1, 3, 3, 5, 7, 9}); + this.add(new Double[]{12.0, 12.0, 12.0, 14.0, 14.0, 13.0, 11.0, 15.0}); + } + }}; + try (Table expected = expectedBuilder + .column(nestedType, (HostColumnVector.StructData[]) expectedData.get(0)) + .column((String[]) expectedData.get(1)) + .column((Integer[]) expectedData.get(2)) + .column((Double[]) expectedData.get(3)) .build()) { return new Table[]{new Table(input.getColumns()), new Table(expected.getColumns())}; } @@ -4679,7 +4698,7 @@ void testExplode() { } // Child is nested type - Table[] testTables2 = buildExplodeTestTableWithNestedTypes(false); + Table[] testTables2 = buildExplodeTestTableWithNestedTypes(false, false); try (Table input = testTables2[0]; Table expected = testTables2[1]) { try (Table exploded = input.explode(0)) { @@ -4689,7 +4708,7 @@ void testExplode() { } @Test - void testPosExplode() { + void testExplodePosition() { // Child is primitive type Table[] testTables = buildExplodeTestTableWithPrimitiveTypes(true, false); try (Table input = testTables[0]; @@ -4699,8 +4718,8 @@ void testPosExplode() { } } - // Child is primitive type - Table[] testTables2 = buildExplodeTestTableWithNestedTypes(true); + // Child is nested type + Table[] testTables2 = buildExplodeTestTableWithNestedTypes(true, false); try (Table input = testTables2[0]; Table expected = testTables2[1]) { try (Table exploded = input.explodePosition(0)) { @@ -4709,4 +4728,45 @@ void testPosExplode() { } } + @Test + void testExplodeOuter() { + // Child is primitive type + Table[] testTables = buildExplodeTestTableWithPrimitiveTypes(false, true); + try (Table input = testTables[0]; + Table expected = testTables[1]) { + try (Table exploded = input.explodeOuter(0)) { + assertTablesAreEqual(expected, exploded); + } + } + + // Child is nested type + Table[] testTables2 = buildExplodeTestTableWithNestedTypes(false, true); + try (Table input = testTables2[0]; + Table expected = testTables2[1]) { + try (Table exploded = input.explodeOuter(0)) { + assertTablesAreEqual(expected, exploded); + } + } + } + + @Test + void testExplodeOuterPosition() { + // Child is primitive type + Table[] testTables = buildExplodeTestTableWithPrimitiveTypes(true, true); + try (Table input = testTables[0]; + Table expected = testTables[1]) { + try (Table exploded = input.explodeOuterPosition(0)) { + assertTablesAreEqual(expected, exploded); + } + } + + // Child is nested type + Table[] testTables2 = buildExplodeTestTableWithNestedTypes(true, true); + try (Table input = testTables2[0]; + Table expected = testTables2[1]) { + try (Table exploded = input.explodeOuterPosition(0)) { + assertTablesAreEqual(expected, exploded); + } + } + } }