From 5ee9b66c24c3686557ea9f60030b9e13d4647a15 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Tue, 27 Feb 2024 10:06:47 -0600 Subject: [PATCH 1/2] Java bindings for left outer distinct join --- java/src/main/java/ai/rapids/cudf/Table.java | 27 +++++ java/src/main/native/src/TableJni.cpp | 24 ++++ .../test/java/ai/rapids/cudf/TableTest.java | 114 ++++++++++++++++++ 3 files changed, 165 insertions(+) diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 9a790c8518b..0fb0a2e98cc 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -608,6 +608,9 @@ private static native long[] merge(long[] tableHandles, int[] sortKeyIndexes, private static native long[] leftJoinGatherMaps(long leftKeys, long rightKeys, boolean compareNullsEqual) throws CudfException; + private static native long[] leftDistinctJoinGatherMaps(long leftKeys, long rightKeys, + boolean compareNullsEqual) throws CudfException; + private static native long leftJoinRowCount(long leftTable, long rightHashJoin) throws CudfException; private static native long[] leftHashJoinGatherMaps(long leftTable, long rightHashJoin) throws CudfException; @@ -2925,6 +2928,30 @@ public GatherMap[] leftJoinGatherMaps(Table rightKeys, boolean compareNullsEqual return buildJoinGatherMaps(gatherMapData); } + /** + * Computes the gather maps that can be used to manifest the result of a left outer equi-join between + * two tables where the right table is guaranteed to not contain any duplicated join keys. It is + * assumed this table instance holds the key columns from the left table, and the table argument + * represents the key columns from the right table. Two {@link GatherMap} instances will be + * returned that can be used to gather the left and right tables, respectively, to produce the + * result of the left outer join. + * + * It is the responsibility of the caller to close the resulting gather map instances. + * + * @param rightKeys join key columns from the right table + * @param compareNullsEqual true if null key values should match otherwise false + * @return left and right table gather maps + */ + public GatherMap[] leftDistinctJoinGatherMaps(Table rightKeys, boolean compareNullsEqual) { + if (getNumberOfColumns() != rightKeys.getNumberOfColumns()) { + throw new IllegalArgumentException("Column count mismatch, this: " + getNumberOfColumns() + + "rightKeys: " + rightKeys.getNumberOfColumns()); + } + long[] gatherMapData = + leftDistinctJoinGatherMaps(getNativeView(), rightKeys.getNativeView(), compareNullsEqual); + return buildJoinGatherMaps(gatherMapData); + } + /** * Computes the number of rows resulting from a left equi-join between two tables. * It is assumed this table instance holds the key columns from the left table, and the diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index 1d6f1332b06..cf164e5538b 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -2422,6 +2422,30 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftJoinGatherMaps( }); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftDistinctJoinGatherMaps( + JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) { + return cudf::jni::join_gather_maps( + env, j_left_keys, j_right_keys, compare_nulls_equal, + [](cudf::table_view const &left, cudf::table_view const &right, cudf::null_equality nulleq) { + auto has_nulls = cudf::has_nested_nulls(left) || cudf::has_nested_nulls(right) ? + cudf::nullable_join::YES : + cudf::nullable_join::NO; + std::pair>, + std::unique_ptr>> + maps; + if (cudf::detail::has_nested_columns(right)) { + cudf::distinct_hash_join hash(right, left, has_nulls, nulleq); + maps = hash.left_join(); + } else { + cudf::distinct_hash_join hash(right, left, has_nulls, nulleq); + maps = hash.left_join(); + } + // Unique join returns {right map, left map} but all the other joins + // return {left map, right map}. Swap here to make it consistent. + return std::make_pair(std::move(maps.second), std::move(maps.first)); + }); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_leftJoinRowCount(JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join) { diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index e270c4a5183..eff0d3c9ed0 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -1697,6 +1697,120 @@ void testLeftJoinGatherMapsNulls() { } } + private void checkLeftDistinctJoin(Table leftKeys, Table rightKeys, Table expected, + boolean compareNullsEqual) { + GatherMap[] maps = leftKeys.leftDistinctJoinGatherMaps(rightKeys, compareNullsEqual); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + + @Test + void testLeftDistinctJoinGatherMaps() { + final int inv = Integer.MIN_VALUE; + try (Table leftKeys = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8, 6).build(); + Table rightKeys = new Table.TestBuilder().column(6, 5, 9, 8, 10, 32).build(); + Table expected = new Table.TestBuilder() + .column( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) // left + .column(inv, inv, 2, inv, inv, inv, inv, 0, 1, 3, 0) // right + .build()) { + checkLeftDistinctJoin(leftKeys, rightKeys, expected, false); + } + } + + @Test + void testLeftDistinctJoinGatherMapsWithNested() { + final int inv = Integer.MIN_VALUE; + StructType structType = new StructType(false, + new BasicType(false, DType.STRING), + new BasicType(false, DType.INT32)); + StructData[] leftData = new StructData[]{ + new StructData("abc", 1), + new StructData("xyz", 1), + new StructData("abc", 2), + new StructData("xyz", 2), + new StructData("abc", 1), + new StructData("abc", 3), + new StructData("xyz", 3) + }; + StructData[] rightData = new StructData[]{ + new StructData("abc", 1), + new StructData("xyz", 4), + new StructData("xyz", 2), + new StructData("abc", -1), + }; + try (Table leftKeys = new Table.TestBuilder().column(structType, leftData).build(); + Table rightKeys = new Table.TestBuilder().column(structType, rightData).build(); + Table expected = new Table.TestBuilder() + .column(0, 1, 2, 3, 4, 5, 6) + .column(0, inv, inv, 2, 0, inv, inv) + .build()) { + checkLeftDistinctJoin(leftKeys, rightKeys, expected, false); + } + } + + @Test + void testLeftDistinctJoinGatherMapsNullsEqual() { + final int inv = Integer.MIN_VALUE; + try (Table leftKeys = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) + .build(); + Table rightKeys = new Table.TestBuilder() + .column(null, 9, 8, 10, 32) + .build(); + Table expected = new Table.TestBuilder() + .column(0, 1, 2, 3, 4, 5, 6, 7, 8, 9) // left + .column(inv, inv, 1, inv, inv, inv, inv, 0, 0, 2) // right + .build()) { + checkLeftDistinctJoin(leftKeys, rightKeys, expected, true); + } + } + + @Test + void testLeftDistinctJoinGatherMapsWithNestedNullsEqual() { + final int inv = Integer.MIN_VALUE; + StructType structType = new StructType(true, + new BasicType(true, DType.STRING), + new BasicType(true, DType.INT32)); + StructData[] leftData = new StructData[]{ + new StructData("abc", 1), + null, + new StructData("xyz", 1), + new StructData("abc", 2), + new StructData("xyz", null), + null, + new StructData("abc", 1), + new StructData("abc", 3), + new StructData("xyz", 3), + new StructData(null, null), + new StructData(null, 1) + }; + StructData[] rightData = new StructData[]{ + null, + new StructData("abc", 1), + new StructData("xyz", 4), + new StructData("xyz", 2), + new StructData(null, null), + new StructData(null, 2), + new StructData(null, 1), + new StructData("xyz", null), + new StructData("abc", null), + new StructData("abc", -1) + }; + try (Table leftKeys = new Table.TestBuilder().column(structType, leftData).build(); + Table rightKeys = new Table.TestBuilder().column(structType, rightData).build(); + Table expected = new Table.TestBuilder() + .column(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + .column(1, 0, inv, inv, 7, 0, 1, inv, inv, 4, 6) + .build()) { + checkLeftDistinctJoin(leftKeys, rightKeys, expected, true); + } + } + @Test void testLeftHashJoinGatherMaps() { final int inv = Integer.MIN_VALUE; From e1e0350cc041a14812e1e3ff7ae9b30b3ca32f6c Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Wed, 28 Feb 2024 10:47:05 -0600 Subject: [PATCH 2/2] Update to new API that only returns a single gather map --- java/src/main/java/ai/rapids/cudf/Table.java | 51 ++++++++++--------- java/src/main/native/src/TableJni.cpp | 14 ++--- .../test/java/ai/rapids/cudf/TableTest.java | 33 ++++-------- 3 files changed, 41 insertions(+), 57 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 0fb0a2e98cc..fef96759a34 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -608,8 +608,8 @@ private static native long[] merge(long[] tableHandles, int[] sortKeyIndexes, private static native long[] leftJoinGatherMaps(long leftKeys, long rightKeys, boolean compareNullsEqual) throws CudfException; - private static native long[] leftDistinctJoinGatherMaps(long leftKeys, long rightKeys, - boolean compareNullsEqual) throws CudfException; + private static native long[] leftDistinctJoinGatherMap(long leftKeys, long rightKeys, + boolean compareNullsEqual) throws CudfException; private static native long leftJoinRowCount(long leftTable, long rightHashJoin) throws CudfException; @@ -2929,27 +2929,30 @@ public GatherMap[] leftJoinGatherMaps(Table rightKeys, boolean compareNullsEqual } /** - * Computes the gather maps that can be used to manifest the result of a left outer equi-join between - * two tables where the right table is guaranteed to not contain any duplicated join keys. It is - * assumed this table instance holds the key columns from the left table, and the table argument - * represents the key columns from the right table. Two {@link GatherMap} instances will be - * returned that can be used to gather the left and right tables, respectively, to produce the - * result of the left outer join. + * Computes a gather map that can be used to manifest the result of a left equi-join between + * two tables where the right table is guaranteed to not contain any duplicated join keys. + * The left table can be used as-is to produce the left table columns resulting from the join, + * i.e.: left table ordering is preserved in the join result, so no gather map is required for + * the left table. The resulting gather map can be applied to the right table to produce the + * right table columns resulting from the join. It is assumed this table instance holds the + * key columns from the left table, and the table argument represents the key columns from the + * right table. A {@link GatherMap} instance will be returned that can be used to gather the + * right table and that result combined with the left table to produce a left outer join result. * - * It is the responsibility of the caller to close the resulting gather map instances. + * It is the responsibility of the caller to close the resulting gather map instance. * * @param rightKeys join key columns from the right table * @param compareNullsEqual true if null key values should match otherwise false - * @return left and right table gather maps + * @return right table gather map */ - public GatherMap[] leftDistinctJoinGatherMaps(Table rightKeys, boolean compareNullsEqual) { + public GatherMap leftDistinctJoinGatherMap(Table rightKeys, boolean compareNullsEqual) { if (getNumberOfColumns() != rightKeys.getNumberOfColumns()) { throw new IllegalArgumentException("Column count mismatch, this: " + getNumberOfColumns() + "rightKeys: " + rightKeys.getNumberOfColumns()); } long[] gatherMapData = - leftDistinctJoinGatherMaps(getNativeView(), rightKeys.getNativeView(), compareNullsEqual); - return buildJoinGatherMaps(gatherMapData); + leftDistinctJoinGatherMap(getNativeView(), rightKeys.getNativeView(), compareNullsEqual); + return buildSingleJoinGatherMap(gatherMapData); } /** @@ -3509,7 +3512,7 @@ public static GatherMap[] mixedFullJoinGatherMaps(Table leftKeys, Table rightKey return buildJoinGatherMaps(gatherMapData); } - private static GatherMap buildSemiJoinGatherMap(long[] gatherMapData) { + private static GatherMap buildSingleJoinGatherMap(long[] gatherMapData) { long bufferSize = gatherMapData[0]; long leftAddr = gatherMapData[1]; long leftHandle = gatherMapData[2]; @@ -3534,7 +3537,7 @@ public GatherMap leftSemiJoinGatherMap(Table rightKeys, boolean compareNullsEqua } long[] gatherMapData = leftSemiJoinGatherMap(getNativeView(), rightKeys.getNativeView(), compareNullsEqual); - return buildSemiJoinGatherMap(gatherMapData); + return buildSingleJoinGatherMap(gatherMapData); } /** @@ -3567,7 +3570,7 @@ public GatherMap conditionalLeftSemiJoinGatherMap(Table rightTable, long[] gatherMapData = conditionalLeftSemiJoinGatherMap(getNativeView(), rightTable.getNativeView(), condition.getNativeHandle()); - return buildSemiJoinGatherMap(gatherMapData); + return buildSingleJoinGatherMap(gatherMapData); } /** @@ -3592,7 +3595,7 @@ public GatherMap conditionalLeftSemiJoinGatherMap(Table rightTable, long[] gatherMapData = conditionalLeftSemiJoinGatherMapWithCount(getNativeView(), rightTable.getNativeView(), condition.getNativeHandle(), outputRowCount); - return buildSemiJoinGatherMap(gatherMapData); + return buildSingleJoinGatherMap(gatherMapData); } /** @@ -3647,7 +3650,7 @@ public static GatherMap mixedLeftSemiJoinGatherMap(Table leftKeys, Table rightKe leftConditional.getNativeView(), rightConditional.getNativeView(), condition.getNativeHandle(), nullEquality == NullEquality.EQUAL); - return buildSemiJoinGatherMap(gatherMapData); + return buildSingleJoinGatherMap(gatherMapData); } /** @@ -3680,7 +3683,7 @@ public static GatherMap mixedLeftSemiJoinGatherMap(Table leftKeys, Table rightKe condition.getNativeHandle(), nullEquality == NullEquality.EQUAL, joinSize.getOutputRowCount(), joinSize.getMatches().getNativeView()); - return buildSemiJoinGatherMap(gatherMapData); + return buildSingleJoinGatherMap(gatherMapData); } /** @@ -3701,7 +3704,7 @@ public GatherMap leftAntiJoinGatherMap(Table rightKeys, boolean compareNullsEqua } long[] gatherMapData = leftAntiJoinGatherMap(getNativeView(), rightKeys.getNativeView(), compareNullsEqual); - return buildSemiJoinGatherMap(gatherMapData); + return buildSingleJoinGatherMap(gatherMapData); } /** @@ -3734,7 +3737,7 @@ public GatherMap conditionalLeftAntiJoinGatherMap(Table rightTable, long[] gatherMapData = conditionalLeftAntiJoinGatherMap(getNativeView(), rightTable.getNativeView(), condition.getNativeHandle()); - return buildSemiJoinGatherMap(gatherMapData); + return buildSingleJoinGatherMap(gatherMapData); } /** @@ -3759,7 +3762,7 @@ public GatherMap conditionalLeftAntiJoinGatherMap(Table rightTable, long[] gatherMapData = conditionalLeftAntiJoinGatherMapWithCount(getNativeView(), rightTable.getNativeView(), condition.getNativeHandle(), outputRowCount); - return buildSemiJoinGatherMap(gatherMapData); + return buildSingleJoinGatherMap(gatherMapData); } /** @@ -3814,7 +3817,7 @@ public static GatherMap mixedLeftAntiJoinGatherMap(Table leftKeys, Table rightKe leftConditional.getNativeView(), rightConditional.getNativeView(), condition.getNativeHandle(), nullEquality == NullEquality.EQUAL); - return buildSemiJoinGatherMap(gatherMapData); + return buildSingleJoinGatherMap(gatherMapData); } /** @@ -3847,7 +3850,7 @@ public static GatherMap mixedLeftAntiJoinGatherMap(Table leftKeys, Table rightKe condition.getNativeHandle(), nullEquality == NullEquality.EQUAL, joinSize.getOutputRowCount(), joinSize.getMatches().getNativeView()); - return buildSemiJoinGatherMap(gatherMapData); + return buildSingleJoinGatherMap(gatherMapData); } /** diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index cf164e5538b..22b4d10e008 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -2422,27 +2422,21 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftJoinGatherMaps( }); } -JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftDistinctJoinGatherMaps( +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftDistinctJoinGatherMap( JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) { - return cudf::jni::join_gather_maps( + return cudf::jni::join_gather_single_map( env, j_left_keys, j_right_keys, compare_nulls_equal, [](cudf::table_view const &left, cudf::table_view const &right, cudf::null_equality nulleq) { auto has_nulls = cudf::has_nested_nulls(left) || cudf::has_nested_nulls(right) ? cudf::nullable_join::YES : cudf::nullable_join::NO; - std::pair>, - std::unique_ptr>> - maps; if (cudf::detail::has_nested_columns(right)) { cudf::distinct_hash_join hash(right, left, has_nulls, nulleq); - maps = hash.left_join(); + return hash.left_join(); } else { cudf::distinct_hash_join hash(right, left, has_nulls, nulleq); - maps = hash.left_join(); + return hash.left_join(); } - // Unique join returns {right map, left map} but all the other joins - // return {left map, right map}. Swap here to make it consistent. - return std::make_pair(std::move(maps.second), std::move(maps.first)); }); } diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index eff0d3c9ed0..17d663b408f 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -1697,14 +1697,13 @@ void testLeftJoinGatherMapsNulls() { } } - private void checkLeftDistinctJoin(Table leftKeys, Table rightKeys, Table expected, + private void checkLeftDistinctJoin(Table leftKeys, Table rightKeys, ColumnView expected, boolean compareNullsEqual) { - GatherMap[] maps = leftKeys.leftDistinctJoinGatherMaps(rightKeys, compareNullsEqual); - try { - verifyJoinGatherMaps(maps, expected); - } finally { - for (GatherMap map : maps) { - map.close(); + try (GatherMap map = leftKeys.leftDistinctJoinGatherMap(rightKeys, compareNullsEqual)) { + int numRows = (int) expected.getRowCount(); + assertEquals(numRows, map.getRowCount()); + try (ColumnView view = map.toColumnView(0, numRows)) { + assertColumnsAreEqual(expected, view); } } } @@ -1714,10 +1713,7 @@ void testLeftDistinctJoinGatherMaps() { final int inv = Integer.MIN_VALUE; try (Table leftKeys = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8, 6).build(); Table rightKeys = new Table.TestBuilder().column(6, 5, 9, 8, 10, 32).build(); - Table expected = new Table.TestBuilder() - .column( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) // left - .column(inv, inv, 2, inv, inv, inv, inv, 0, 1, 3, 0) // right - .build()) { + ColumnVector expected = ColumnVector.fromInts(inv, inv, 2, inv, inv, inv, inv, 0, 1, 3, 0)) { checkLeftDistinctJoin(leftKeys, rightKeys, expected, false); } } @@ -1745,10 +1741,7 @@ void testLeftDistinctJoinGatherMapsWithNested() { }; try (Table leftKeys = new Table.TestBuilder().column(structType, leftData).build(); Table rightKeys = new Table.TestBuilder().column(structType, rightData).build(); - Table expected = new Table.TestBuilder() - .column(0, 1, 2, 3, 4, 5, 6) - .column(0, inv, inv, 2, 0, inv, inv) - .build()) { + ColumnVector expected = ColumnVector.fromInts(0, inv, inv, 2, 0, inv, inv)) { checkLeftDistinctJoin(leftKeys, rightKeys, expected, false); } } @@ -1762,10 +1755,7 @@ void testLeftDistinctJoinGatherMapsNullsEqual() { Table rightKeys = new Table.TestBuilder() .column(null, 9, 8, 10, 32) .build(); - Table expected = new Table.TestBuilder() - .column(0, 1, 2, 3, 4, 5, 6, 7, 8, 9) // left - .column(inv, inv, 1, inv, inv, inv, inv, 0, 0, 2) // right - .build()) { + ColumnVector expected = ColumnVector.fromInts(inv, inv, 1, inv, inv, inv, inv, 0, 0, 2)) { checkLeftDistinctJoin(leftKeys, rightKeys, expected, true); } } @@ -1803,10 +1793,7 @@ void testLeftDistinctJoinGatherMapsWithNestedNullsEqual() { }; try (Table leftKeys = new Table.TestBuilder().column(structType, leftData).build(); Table rightKeys = new Table.TestBuilder().column(structType, rightData).build(); - Table expected = new Table.TestBuilder() - .column(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - .column(1, 0, inv, inv, 7, 0, 1, inv, inv, 4, 6) - .build()) { + ColumnVector expected = ColumnVector.fromInts(1, 0, inv, inv, 7, 0, 1, inv, inv, 4, 6)) { checkLeftDistinctJoin(leftKeys, rightKeys, expected, true); } }