From f3480fde08210cdae728dd106d51fc3d8727e5d1 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Wed, 18 Aug 2021 17:21:24 -0500 Subject: [PATCH] Java bindings for cudf::hash_join --- .../main/java/ai/rapids/cudf/HashJoin.java | 127 ++++++++ .../java/ai/rapids/cudf/MemoryCleaner.java | 4 + java/src/main/java/ai/rapids/cudf/Table.java | 220 +++++++++++++ java/src/main/native/CMakeLists.txt | 1 + java/src/main/native/src/HashJoinJni.cpp | 45 +++ java/src/main/native/src/TableJni.cpp | 237 ++++++++++---- .../java/ai/rapids/cudf/HashJoinTest.java | 45 +++ .../test/java/ai/rapids/cudf/TableTest.java | 288 +++++++++++++++++- 8 files changed, 900 insertions(+), 67 deletions(-) create mode 100644 java/src/main/java/ai/rapids/cudf/HashJoin.java create mode 100644 java/src/main/native/src/HashJoinJni.cpp create mode 100644 java/src/test/java/ai/rapids/cudf/HashJoinTest.java diff --git a/java/src/main/java/ai/rapids/cudf/HashJoin.java b/java/src/main/java/ai/rapids/cudf/HashJoin.java new file mode 100644 index 00000000000..620a7ce6a6c --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/HashJoin.java @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.rapids.cudf; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This class represents a hash table built from the join keys of the right-side table for a + * join operation. This hash table can then be reused across a series of left probe tables + * to compute gather maps for joins more efficiently when the right-side table is not changing. + * It can also be used to query the output row count of a join and then pass that result to the + * operation that generates the join gather maps to avoid redundant computation when the output + * row count must be checked before manifesting the join gather maps. + */ +public class HashJoin implements AutoCloseable { + static { + NativeDepsLoader.loadNativeDeps(); + } + + private static final Logger log = LoggerFactory.getLogger(HashJoin.class); + + private static class HashJoinCleaner extends MemoryCleaner.Cleaner { + private Table buildKeys; + private long nativeHandle; + + HashJoinCleaner(Table buildKeys, long nativeHandle) { + this.buildKeys = buildKeys; + this.nativeHandle = nativeHandle; + addRef(); + } + + @Override + protected synchronized boolean cleanImpl(boolean logErrorIfNotClean) { + long origAddress = nativeHandle; + boolean neededCleanup = nativeHandle != 0; + if (neededCleanup) { + try { + destroy(nativeHandle); + buildKeys.close(); + buildKeys = null; + } finally { + nativeHandle = 0; + } + if (logErrorIfNotClean) { + log.error("A HASH TABLE WAS LEAKED (ID: " + id + " " + Long.toHexString(origAddress)); + } + } + return neededCleanup; + } + + @Override + public boolean isClean() { + return nativeHandle == 0; + } + } + + private final HashJoinCleaner cleaner; + private final boolean compareNulls; + private boolean isClosed = false; + + /** + * Construct a hash table for a join from a table representing the join key columns from the + * right-side table in the join. The resulting instance must be closed to release the + * GPU resources associated with the instance. + * @param buildKeys table view containing the join keys for the right-side join table + * @param compareNulls true if null key values should match otherwise false + */ + public HashJoin(Table buildKeys, boolean compareNulls) { + this.compareNulls = compareNulls; + Table buildTable = new Table(buildKeys.getColumns()); + try { + long handle = create(buildTable.getNativeView(), compareNulls); + this.cleaner = new HashJoinCleaner(buildTable, handle); + MemoryCleaner.register(this, cleaner); + } catch (Throwable t) { + try { + buildTable.close(); + } catch (Throwable t2) { + t.addSuppressed(t2); + } + throw t; + } + } + + @Override + public synchronized void close() { + cleaner.delRef(); + if (isClosed) { + cleaner.logRefCountDebug("double free " + this); + throw new IllegalStateException("Close called too many times " + this); + } + cleaner.clean(false); + isClosed = true; + } + + long getNativeView() { + return cleaner.nativeHandle; + } + + /** Get the number of join key columns for the table that was used to generate the has table. */ + public long getNumberOfColumns() { + return cleaner.buildKeys.getNumberOfColumns(); + } + + /** Returns true if the hash table was built to match on nulls otherwise false. */ + public boolean getCompareNulls() { + return compareNulls; + } + + private static native long create(long tableView, boolean nullEqual); + private static native void destroy(long handle); +} diff --git a/java/src/main/java/ai/rapids/cudf/MemoryCleaner.java b/java/src/main/java/ai/rapids/cudf/MemoryCleaner.java index 4bf38543a2d..a936d4830ee 100644 --- a/java/src/main/java/ai/rapids/cudf/MemoryCleaner.java +++ b/java/src/main/java/ai/rapids/cudf/MemoryCleaner.java @@ -277,6 +277,10 @@ public static void register(CompiledExpression expr, Cleaner cleaner) { all.add(new CleanerWeakReference(expr, cleaner, collected, false)); } + static void register(HashJoin hashJoin, Cleaner cleaner) { + all.add(new CleanerWeakReference(hashJoin, cleaner, collected, true)); + } + /** * This is not 100% perfect and we can still run into situations where RMM buffers were not * collected and this returns false because of thread race conditions. This is just a best effort. diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 1fc9616d607..e725932ed5e 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -505,18 +505,48 @@ private static native long[] leftJoin(long leftTable, int[] leftJoinCols, long r private static native long[] leftJoinGatherMaps(long leftKeys, long rightKeys, boolean compareNullsEqual) throws CudfException; + private static native long leftJoinRowCount(long leftTable, long rightHashJoin, + boolean nullsEqual) throws CudfException; + + private static native long[] leftHashJoinGatherMaps(long leftTable, long rightHashJoin, + boolean nullsEqual) throws CudfException; + + private static native long[] leftHashJoinGatherMapsWithCount(long leftTable, long rightHashJoin, + boolean nullsEqual, + long outputRowCount) throws CudfException; + private static native long[] innerJoin(long leftTable, int[] leftJoinCols, long rightTable, int[] rightJoinCols, boolean compareNullsEqual) throws CudfException; private static native long[] innerJoinGatherMaps(long leftKeys, long rightKeys, boolean compareNullsEqual) throws CudfException; + private static native long innerJoinRowCount(long table, long hashJoin, + boolean nullsEqual) throws CudfException; + + private static native long[] innerHashJoinGatherMaps(long table, long hashJoin, + boolean nullsEqual) throws CudfException; + + private static native long[] innerHashJoinGatherMapsWithCount(long table, long hashJoin, + boolean nullsEqual, + long outputRowCount) throws CudfException; + private static native long[] fullJoin(long leftTable, int[] leftJoinCols, long rightTable, int[] rightJoinCols, boolean compareNullsEqual) throws CudfException; private static native long[] fullJoinGatherMaps(long leftKeys, long rightKeys, boolean compareNullsEqual) throws CudfException; + private static native long fullJoinRowCount(long leftTable, long rightHashJoin, + boolean nullsEqual) throws CudfException; + + private static native long[] fullHashJoinGatherMaps(long leftTable, long rightHashJoin, + boolean nullsEqual) throws CudfException; + + private static native long[] fullHashJoinGatherMapsWithCount(long leftTable, long rightHashJoin, + boolean nullsEqual, + long outputRowCount) throws CudfException; + private static native long[] leftSemiJoin(long leftTable, int[] leftJoinCols, long rightTable, int[] rightJoinCols, boolean compareNullsEqual) throws CudfException; @@ -2040,6 +2070,69 @@ public GatherMap[] leftJoinGatherMaps(Table rightKeys, boolean compareNullsEqual return buildJoinGatherMaps(gatherMapData); } + /** + * Computes the number of rows resulting from a left equi-join between two tables. + * It is assumed this table instance holds the key columns from the left table, and the + * {@link HashJoin} argument has been constructed from the key columns from the right table. + * @param rightHash hash table built from join key columns from the right table + * @return row count of the join result + */ + public long leftJoinRowCount(HashJoin rightHash) { + if (getNumberOfColumns() != rightHash.getNumberOfColumns()) { + throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + "rightKeys: " + rightHash.getNumberOfColumns()); + } + return leftJoinRowCount(getNativeView(), rightHash.getNativeView(), + rightHash.getCompareNulls()); + } + + /** + * Computes the gather maps that can be used to manifest the result of a left equi-join between + * two tables. It is assumed this table instance holds the key columns from the left table, and + * the {@link HashJoin} argument has been constructed from the key columns from the right table. + * Two {@link GatherMap} instances will be returned that can be used to gather the left and right + * tables, respectively, to produce the result of the left join. + * It is the responsibility of the caller to close the resulting gather map instances. + * @param rightHash hash table built from join key columns from the right table + * @return left and right table gather maps + */ + public GatherMap[] leftJoinGatherMaps(HashJoin rightHash) { + if (getNumberOfColumns() != rightHash.getNumberOfColumns()) { + throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + "rightKeys: " + rightHash.getNumberOfColumns()); + } + long[] gatherMapData = + leftHashJoinGatherMaps(getNativeView(), rightHash.getNativeView(), + rightHash.getCompareNulls()); + return buildJoinGatherMaps(gatherMapData); + } + + /** + * Computes the gather maps that can be used to manifest the result of a left equi-join between + * two tables. It is assumed this table instance holds the key columns from the left table, and + * the {@link HashJoin} argument has been constructed from the key columns from the right table. + * Two {@link GatherMap} instances will be returned that can be used to gather the left and right + * tables, respectively, to produce the result of the left join. + * It is the responsibility of the caller to close the resulting gather map instances. + * This interface allows passing an output row count that was previously computed from + * {@link #leftJoinRowCount(HashJoin)}. + * WARNING: Passing a row count that is smaller than the actual row count will result + * in undefined behavior. + * @param rightHash hash table built from join key columns from the right table + * @param outputRowCount number of output rows in the join result + * @return left and right table gather maps + */ + public GatherMap[] leftJoinGatherMaps(HashJoin rightHash, long outputRowCount) { + if (getNumberOfColumns() != rightHash.getNumberOfColumns()) { + throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + "rightKeys: " + rightHash.getNumberOfColumns()); + } + long[] gatherMapData = + leftHashJoinGatherMapsWithCount(getNativeView(), rightHash.getNativeView(), + rightHash.getCompareNulls(), outputRowCount); + return buildJoinGatherMaps(gatherMapData); + } + /** * Computes the number of rows from the result of a left join between two tables when a * conditional expression is true. It is assumed this table instance holds the columns from @@ -2124,6 +2217,67 @@ public GatherMap[] innerJoinGatherMaps(Table rightKeys, boolean compareNullsEqua return buildJoinGatherMaps(gatherMapData); } + /** + * Computes the number of rows resulting from an inner equi-join between two tables. + * @param otherHash hash table built from join key columns from the other table + * @return row count of the join result + */ + public long innerJoinRowCount(HashJoin otherHash) { + if (getNumberOfColumns() != otherHash.getNumberOfColumns()) { + throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + "otherKeys: " + otherHash.getNumberOfColumns()); + } + return innerJoinRowCount(getNativeView(), otherHash.getNativeView(), + otherHash.getCompareNulls()); + } + + /** + * Computes the gather maps that can be used to manifest the result of an inner equi-join between + * two tables. It is assumed this table instance holds the key columns from the left table, and + * the {@link HashJoin} argument has been constructed from the key columns from the right table. + * Two {@link GatherMap} instances will be returned that can be used to gather the left and right + * tables, respectively, to produce the result of the inner join. + * It is the responsibility of the caller to close the resulting gather map instances. + * @param rightHash hash table built from join key columns from the right table + * @return left and right table gather maps + */ + public GatherMap[] innerJoinGatherMaps(HashJoin rightHash) { + if (getNumberOfColumns() != rightHash.getNumberOfColumns()) { + throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + "rightKeys: " + rightHash.getNumberOfColumns()); + } + long[] gatherMapData = + innerHashJoinGatherMaps(getNativeView(), rightHash.getNativeView(), + rightHash.getCompareNulls()); + return buildJoinGatherMaps(gatherMapData); + } + + /** + * Computes the gather maps that can be used to manifest the result of an inner equi-join between + * two tables. It is assumed this table instance holds the key columns from the left table, and + * the {@link HashJoin} argument has been constructed from the key columns from the right table. + * Two {@link GatherMap} instances will be returned that can be used to gather the left and right + * tables, respectively, to produce the result of the inner join. + * It is the responsibility of the caller to close the resulting gather map instances. + * This interface allows passing an output row count that was previously computed from + * {@link #innerJoinRowCount(HashJoin)}. + * WARNING: Passing a row count that is smaller than the actual row count will result + * in undefined behavior. + * @param rightHash hash table built from join key columns from the right table + * @param outputRowCount number of output rows in the join result + * @return left and right table gather maps + */ + public GatherMap[] innerJoinGatherMaps(HashJoin rightHash, long outputRowCount) { + if (getNumberOfColumns() != rightHash.getNumberOfColumns()) { + throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + "rightKeys: " + rightHash.getNumberOfColumns()); + } + long[] gatherMapData = + innerHashJoinGatherMapsWithCount(getNativeView(), rightHash.getNativeView(), + rightHash.getCompareNulls(), outputRowCount); + return buildJoinGatherMaps(gatherMapData); + } + /** * Computes the number of rows from the result of an inner join between two tables when a * conditional expression is true. It is assumed this table instance holds the columns from @@ -2209,6 +2363,72 @@ public GatherMap[] fullJoinGatherMaps(Table rightKeys, boolean compareNullsEqual return buildJoinGatherMaps(gatherMapData); } + /** + * Computes the number of rows resulting from a full equi-join between two tables. + * It is assumed this table instance holds the key columns from the left table, and the + * {@link HashJoin} argument has been constructed from the key columns from the right table. + * Note that unlike {@link #leftJoinRowCount(HashJoin)} and {@link #innerJoinRowCount(HashJoin), + * this will perform some redundant calculations compared to + * {@link #fullJoinGatherMaps(HashJoin, long)}. + * @param rightHash hash table built from join key columns from the right table + * @return row count of the join result + */ + public long fullJoinRowCount(HashJoin rightHash) { + if (getNumberOfColumns() != rightHash.getNumberOfColumns()) { + throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + "rightKeys: " + rightHash.getNumberOfColumns()); + } + return fullJoinRowCount(getNativeView(), rightHash.getNativeView(), + rightHash.getCompareNulls()); + } + + /** + * Computes the gather maps that can be used to manifest the result of a full equi-join between + * two tables. It is assumed this table instance holds the key columns from the left table, and + * the {@link HashJoin} argument has been constructed from the key columns from the right table. + * Two {@link GatherMap} instances will be returned that can be used to gather the left and right + * tables, respectively, to produce the result of the full join. + * It is the responsibility of the caller to close the resulting gather map instances. + * @param rightHash hash table built from join key columns from the right table + * @return left and right table gather maps + */ + public GatherMap[] fullJoinGatherMaps(HashJoin rightHash) { + if (getNumberOfColumns() != rightHash.getNumberOfColumns()) { + throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + "rightKeys: " + rightHash.getNumberOfColumns()); + } + long[] gatherMapData = + fullHashJoinGatherMaps(getNativeView(), rightHash.getNativeView(), + rightHash.getCompareNulls()); + return buildJoinGatherMaps(gatherMapData); + } + + /** + * Computes the gather maps that can be used to manifest the result of a full equi-join between + * two tables. It is assumed this table instance holds the key columns from the left table, and + * the {@link HashJoin} argument has been constructed from the key columns from the right table. + * Two {@link GatherMap} instances will be returned that can be used to gather the left and right + * tables, respectively, to produce the result of the full join. + * It is the responsibility of the caller to close the resulting gather map instances. + * This interface allows passing an output row count that was previously computed from + * {@link #fullJoinRowCount(HashJoin)}. + * WARNING: Passing a row count that is smaller than the actual row count will result + * in undefined behavior. + * @param rightHash hash table built from join key columns from the right table + * @param outputRowCount number of output rows in the join result + * @return left and right table gather maps + */ + public GatherMap[] fullJoinGatherMaps(HashJoin rightHash, long outputRowCount) { + if (getNumberOfColumns() != rightHash.getNumberOfColumns()) { + throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + "rightKeys: " + rightHash.getNumberOfColumns()); + } + long[] gatherMapData = + fullHashJoinGatherMapsWithCount(getNativeView(), rightHash.getNativeView(), + rightHash.getCompareNulls(), outputRowCount); + return buildJoinGatherMaps(gatherMapData); + } + /** * Computes the gather maps that can be used to manifest the result of a full join between * two tables when a conditional expression is true. It is assumed this table instance holds diff --git a/java/src/main/native/CMakeLists.txt b/java/src/main/native/CMakeLists.txt index 35ecae681b8..bc59e3aee64 100755 --- a/java/src/main/native/CMakeLists.txt +++ b/java/src/main/native/CMakeLists.txt @@ -264,6 +264,7 @@ set(SOURCE_FILES "src/ColumnViewJni.cpp" "src/CompiledExpression.cpp" "src/ContiguousTableJni.cpp" + "src/HashJoinJni.cpp" "src/HostMemoryBufferNativeUtilsJni.cpp" "src/NvcompJni.cpp" "src/NvtxRangeJni.cpp" diff --git a/java/src/main/native/src/HashJoinJni.cpp b/java/src/main/native/src/HashJoinJni.cpp new file mode 100644 index 00000000000..0f78aef64bc --- /dev/null +++ b/java/src/main/native/src/HashJoinJni.cpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "cudf_jni_apis.hpp" + +extern "C" { + +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_HashJoin_create(JNIEnv *env, jclass, jlong j_table, + jboolean j_nulls_equal) { + JNI_NULL_CHECK(env, j_table, "table handle is null", 0); + try { + cudf::jni::auto_set_device(env); + auto tview = reinterpret_cast(j_table); + auto nulleq = j_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; + auto hash_join_ptr = new cudf::hash_join(*tview, nulleq); + return reinterpret_cast(hash_join_ptr); + } + CATCH_STD(env, 0); +} + +JNIEXPORT void JNICALL Java_ai_rapids_cudf_HashJoin_destroy(JNIEnv *env, jclass, jlong j_handle) { + try { + cudf::jni::auto_set_device(env); + auto hash_join_ptr = reinterpret_cast(j_handle); + delete hash_join_ptr; + } + CATCH_STD(env, ); +} + +} // extern "C" diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index 595bc1df151..f642a87b445 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -755,13 +755,46 @@ bool valid_window_parameters(native_jintArray const &values, values.size() == preceding.size() && values.size() == following.size(); } -// Generate gather maps needed to manifest the result of an equi-join between two tables. +// Convert a cudf gather map pair into the form that Java expects // The resulting Java long array contains the following at each index: // 0: Size of each gather map in bytes // 1: Device address of the gather map for the left table // 2: Host address of the rmm::device_buffer instance that owns the left gather map data // 3: Device address of the gather map for the right table // 4: Host address of the rmm::device_buffer instance that owns the right gather map data +jlongArray gather_maps_to_java(JNIEnv *env, + std::pair>, + std::unique_ptr>> + maps) { + // release the underlying device buffer to Java + auto left_map_buffer = std::make_unique(maps.first->release()); + auto right_map_buffer = std::make_unique(maps.second->release()); + cudf::jni::native_jlongArray result(env, 5); + result[0] = static_cast(left_map_buffer->size()); + result[1] = reinterpret_cast(left_map_buffer->data()); + result[2] = reinterpret_cast(left_map_buffer.release()); + result[3] = reinterpret_cast(right_map_buffer->data()); + result[4] = reinterpret_cast(right_map_buffer.release()); + return result.get_jArray(); +} + +// Convert a cudf gather map into the form that Java expects +// The resulting Java long array contains the following at each index: +// 0: Size of the gather map in bytes +// 1: Device address of the gather map +// 2: Host address of the rmm::device_buffer instance that owns the gather map data +jlongArray gather_map_to_java(JNIEnv *env, + std::unique_ptr> map) { + // release the underlying device buffer to Java + auto gather_map_buffer = std::make_unique(map->release()); + cudf::jni::native_jlongArray result(env, 3); + result[0] = static_cast(gather_map_buffer->size()); + result[1] = reinterpret_cast(gather_map_buffer->data()); + result[2] = reinterpret_cast(gather_map_buffer.release()); + return result.get_jArray(); +} + +// Generate gather maps needed to manifest the result of an equi-join between two tables. template jlongArray join_gather_maps(JNIEnv *env, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal, T join_func) { @@ -772,31 +805,29 @@ jlongArray join_gather_maps(JNIEnv *env, jlong j_left_keys, jlong j_right_keys, auto left_keys = reinterpret_cast(j_left_keys); auto right_keys = reinterpret_cast(j_right_keys); auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; - std::pair>, - std::unique_ptr>> - join_maps = join_func(*left_keys, *right_keys, nulleq); - - // release the underlying device buffer to Java - auto left_map_buffer = std::make_unique(join_maps.first->release()); - auto right_map_buffer = std::make_unique(join_maps.second->release()); - cudf::jni::native_jlongArray result(env, 5); - result[0] = static_cast(left_map_buffer->size()); - result[1] = reinterpret_cast(left_map_buffer->data()); - result[2] = reinterpret_cast(left_map_buffer.release()); - result[3] = reinterpret_cast(right_map_buffer->data()); - result[4] = reinterpret_cast(right_map_buffer.release()); - return result.get_jArray(); + return gather_maps_to_java(env, join_func(*left_keys, *right_keys, nulleq)); + } + CATCH_STD(env, NULL); +} + +// Generate gather maps needed to manifest the result of an equi-join between a left table and +// a hash table built from the join's right table. +template +jlongArray hash_join_gather_maps(JNIEnv *env, jlong j_left_keys, jlong j_right_hash_join, + jboolean compare_nulls_equal, T join_func) { + JNI_NULL_CHECK(env, j_left_keys, "left table is null", NULL); + JNI_NULL_CHECK(env, j_right_hash_join, "hash join is null", NULL); + try { + cudf::jni::auto_set_device(env); + auto left_keys = reinterpret_cast(j_left_keys); + auto hash_join = reinterpret_cast(j_right_hash_join); + auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; + return gather_maps_to_java(env, join_func(*left_keys, *hash_join, nulleq)); } CATCH_STD(env, NULL); } // Generate gather maps needed to manifest the result of a conditional join between two tables. -// The resulting Java long array contains the following at each index: -// 0: Size of each gather map in bytes -// 1: Device address of the gather map for the left table -// 2: Host address of the rmm::device_buffer instance that owns the left gather map data -// 3: Device address of the gather map for the right table -// 4: Host address of the rmm::device_buffer instance that owns the right gather map data template jlongArray cond_join_gather_maps(JNIEnv *env, jlong j_left_table, jlong j_right_table, jlong j_condition, jboolean compare_nulls_equal, T join_func) { @@ -809,29 +840,13 @@ jlongArray cond_join_gather_maps(JNIEnv *env, jlong j_left_table, jlong j_right_ auto right_table = reinterpret_cast(j_right_table); auto condition = reinterpret_cast(j_condition); auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; - std::pair>, - std::unique_ptr>> - join_maps = join_func(*left_table, *right_table, condition->get_top_expression(), nulleq); - - // release the underlying device buffer to Java - auto left_map_buffer = std::make_unique(join_maps.first->release()); - auto right_map_buffer = std::make_unique(join_maps.second->release()); - cudf::jni::native_jlongArray result(env, 5); - result[0] = static_cast(left_map_buffer->size()); - result[1] = reinterpret_cast(left_map_buffer->data()); - result[2] = reinterpret_cast(left_map_buffer.release()); - result[3] = reinterpret_cast(right_map_buffer->data()); - result[4] = reinterpret_cast(right_map_buffer.release()); - return result.get_jArray(); + return gather_maps_to_java( + env, join_func(*left_table, *right_table, condition->get_top_expression(), nulleq)); } CATCH_STD(env, NULL); } // Generate a gather map needed to manifest the result of a semi/anti join between two tables. -// The resulting Java long array contains the following at each index: -// 0: Size of the gather map in bytes -// 1: Device address of the gather map -// 2: Host address of the rmm::device_buffer instance that owns the gather map data template jlongArray join_gather_single_map(JNIEnv *env, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal, T join_func) { @@ -842,26 +857,13 @@ jlongArray join_gather_single_map(JNIEnv *env, jlong j_left_keys, jlong j_right_ auto left_keys = reinterpret_cast(j_left_keys); auto right_keys = reinterpret_cast(j_right_keys); auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; - std::unique_ptr> join_map = - join_func(*left_keys, *right_keys, nulleq); - - // release the underlying device buffer to Java - auto gather_map_buffer = std::make_unique(join_map->release()); - cudf::jni::native_jlongArray result(env, 3); - result[0] = static_cast(gather_map_buffer->size()); - result[1] = reinterpret_cast(gather_map_buffer->data()); - result[2] = reinterpret_cast(gather_map_buffer.release()); - return result.get_jArray(); + return gather_map_to_java(env, join_func(*left_keys, *right_keys, nulleq)); } CATCH_STD(env, NULL); } // Generate a gather map needed to manifest the result of a conditional semi/anti join // between two tables. -// The resulting Java long array contains the following at each index: -// 0: Size of the gather map in bytes -// 1: Device address of the gather map -// 2: Host address of the rmm::device_buffer instance that owns the gather map data template jlongArray cond_join_gather_single_map(JNIEnv *env, jlong j_left_table, jlong j_right_table, jlong j_condition, jboolean compare_nulls_equal, @@ -875,16 +877,8 @@ jlongArray cond_join_gather_single_map(JNIEnv *env, jlong j_left_table, jlong j_ auto right_table = reinterpret_cast(j_right_table); auto condition = reinterpret_cast(j_condition); auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; - std::unique_ptr> join_map = - join_func(*left_table, *right_table, condition->get_top_expression(), nulleq); - - // release the underlying device buffer to Java - auto gather_map_buffer = std::make_unique(join_map->release()); - cudf::jni::native_jlongArray result(env, 3); - result[0] = static_cast(gather_map_buffer->size()); - result[1] = reinterpret_cast(gather_map_buffer->data()); - result[2] = reinterpret_cast(gather_map_buffer.release()); - return result.get_jArray(); + return gather_map_to_java( + env, join_func(*left_table, *right_table, condition->get_top_expression(), nulleq)); } CATCH_STD(env, NULL); } @@ -1951,6 +1945,45 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftJoinGatherMaps( }); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_leftJoinRowCount(JNIEnv *env, jclass, + jlong j_left_table, + jlong j_right_hash_join, + jboolean compare_nulls_equal) { + JNI_NULL_CHECK(env, j_left_table, "left table is null", 0); + JNI_NULL_CHECK(env, j_right_hash_join, "right hash join is null", 0); + try { + cudf::jni::auto_set_device(env); + auto left_table = reinterpret_cast(j_left_table); + auto hash_join = reinterpret_cast(j_right_hash_join); + auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; + auto row_count = hash_join->left_join_size(*left_table, nulleq); + return static_cast(row_count); + } + CATCH_STD(env, 0); +} + +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftHashJoinGatherMaps( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join, + jboolean compare_nulls_equal) { + return cudf::jni::hash_join_gather_maps( + env, j_left_table, j_right_hash_join, compare_nulls_equal, + [](cudf::table_view const &left, cudf::hash_join const &hash, cudf::null_equality nulleq) { + return hash.left_join(left, nulleq); + }); +} + +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftHashJoinGatherMapsWithCount( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join, jboolean compare_nulls_equal, + jlong j_output_row_count) { + auto output_row_count = static_cast(j_output_row_count); + return cudf::jni::hash_join_gather_maps(env, j_left_table, j_right_hash_join, compare_nulls_equal, + [output_row_count](cudf::table_view const &left, + cudf::hash_join const &hash, + cudf::null_equality nulleq) { + return hash.left_join(left, nulleq, output_row_count); + }); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_conditionalLeftJoinRowCount( JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, jboolean compare_nulls_equal) { @@ -2002,6 +2035,45 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerJoinGatherMaps( }); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_innerJoinRowCount(JNIEnv *env, jclass, + jlong j_left_table, + jlong j_right_hash_join, + jboolean compare_nulls_equal) { + JNI_NULL_CHECK(env, j_left_table, "left table is null", 0); + JNI_NULL_CHECK(env, j_right_hash_join, "right hash join is null", 0); + try { + cudf::jni::auto_set_device(env); + auto left_table = reinterpret_cast(j_left_table); + auto hash_join = reinterpret_cast(j_right_hash_join); + auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; + auto row_count = hash_join->inner_join_size(*left_table, nulleq); + return static_cast(row_count); + } + CATCH_STD(env, 0); +} + +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerHashJoinGatherMaps( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join, + jboolean compare_nulls_equal) { + return cudf::jni::hash_join_gather_maps( + env, j_left_table, j_right_hash_join, compare_nulls_equal, + [](cudf::table_view const &left, cudf::hash_join const &hash, cudf::null_equality nulleq) { + return hash.inner_join(left, nulleq); + }); +} + +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerHashJoinGatherMapsWithCount( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join, jboolean compare_nulls_equal, + jlong j_output_row_count) { + auto output_row_count = static_cast(j_output_row_count); + return cudf::jni::hash_join_gather_maps(env, j_left_table, j_right_hash_join, compare_nulls_equal, + [output_row_count](cudf::table_view const &left, + cudf::hash_join const &hash, + cudf::null_equality nulleq) { + return hash.inner_join(left, nulleq, output_row_count); + }); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_conditionalInnerJoinRowCount( JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, jboolean compare_nulls_equal) { @@ -2053,6 +2125,45 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_fullJoinGatherMaps( }); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_fullJoinRowCount(JNIEnv *env, jclass, + jlong j_left_table, + jlong j_right_hash_join, + jboolean compare_nulls_equal) { + JNI_NULL_CHECK(env, j_left_table, "left table is null", 0); + JNI_NULL_CHECK(env, j_right_hash_join, "right hash join is null", 0); + try { + cudf::jni::auto_set_device(env); + auto left_table = reinterpret_cast(j_left_table); + auto hash_join = reinterpret_cast(j_right_hash_join); + auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; + auto row_count = hash_join->full_join_size(*left_table, nulleq); + return static_cast(row_count); + } + CATCH_STD(env, 0); +} + +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_fullHashJoinGatherMaps( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join, + jboolean compare_nulls_equal) { + return cudf::jni::hash_join_gather_maps( + env, j_left_table, j_right_hash_join, compare_nulls_equal, + [](cudf::table_view const &left, cudf::hash_join const &hash, cudf::null_equality nulleq) { + return hash.full_join(left, nulleq); + }); +} + +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_fullHashJoinGatherMapsWithCount( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join, jboolean compare_nulls_equal, + jlong j_output_row_count) { + auto output_row_count = static_cast(j_output_row_count); + return cudf::jni::hash_join_gather_maps(env, j_left_table, j_right_hash_join, compare_nulls_equal, + [output_row_count](cudf::table_view const &left, + cudf::hash_join const &hash, + cudf::null_equality nulleq) { + return hash.full_join(left, nulleq, output_row_count); + }); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalFullJoinGatherMaps( JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, jboolean compare_nulls_equal) { diff --git a/java/src/test/java/ai/rapids/cudf/HashJoinTest.java b/java/src/test/java/ai/rapids/cudf/HashJoinTest.java new file mode 100644 index 00000000000..be6125340ec --- /dev/null +++ b/java/src/test/java/ai/rapids/cudf/HashJoinTest.java @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.rapids.cudf; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class HashJoinTest { + @Test + void testGetNumberOfColumns() { + try (Table t = new Table.TestBuilder().column(1, 2).column(3, 4).column(5, 6).build(); + HashJoin hashJoin = new HashJoin(t, false)) { + assertEquals(3, hashJoin.getNumberOfColumns()); + } + } + + @Test + void testGetCompareNulls() { + try (Table t = new Table.TestBuilder().column(1, 2, 3, 4).column(5, 6, 7, 8).build()) { + try (HashJoin hashJoin = new HashJoin(t, false)) { + assertFalse(hashJoin.getCompareNulls()); + } + try (HashJoin hashJoin = new HashJoin(t, true)) { + assertTrue(hashJoin.getCompareNulls()); + } + } + } +} diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 8e4e3df612b..aeb94e4824a 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -30,10 +30,6 @@ import ai.rapids.cudf.ast.ColumnReference; import ai.rapids.cudf.ast.CompiledExpression; import ai.rapids.cudf.ast.TableReference; -import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.ipc.ArrowFileReader; -import org.apache.arrow.vector.ipc.SeekableReadChannel; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.parquet.hadoop.ParquetFileReader; @@ -1500,6 +1496,102 @@ void testLeftJoinGatherMapsNulls() { } } + @Test + void testLeftHashJoinGatherMaps() { + final int inv = Integer.MIN_VALUE; + try (Table leftKeys = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); + Table rightKeys = new Table.TestBuilder().column(6, 5, 9, 8, 10, 32).build(); + HashJoin rightHash = new HashJoin(rightKeys, false); + Table expected = new Table.TestBuilder() + .column( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + .column(inv, inv, 2, inv, inv, inv, inv, 0, 1, 3) + .build()) { + GatherMap[] maps = leftKeys.leftJoinGatherMaps(rightHash); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testLeftHashJoinGatherMapsWithCount() { + final int inv = Integer.MIN_VALUE; + try (Table leftKeys = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); + Table rightKeys = new Table.TestBuilder().column(6, 5, 9, 8, 10, 32).build(); + HashJoin rightHash = new HashJoin(rightKeys, false); + Table expected = new Table.TestBuilder() + .column( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + .column(inv, inv, 2, inv, inv, inv, inv, 0, 1, 3) + .build()) { + long rowCount = leftKeys.leftJoinRowCount(rightHash); + assertEquals(expected.getRowCount(), rowCount); + GatherMap[] maps = leftKeys.leftJoinGatherMaps(rightHash, rowCount); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testLeftHashJoinGatherMapsNulls() { + final int inv = Integer.MIN_VALUE; + try (Table leftKeys = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) + .build(); + Table rightKeys = new Table.TestBuilder() + .column(null, null, 9, 8, 10, 32) + .build(); + HashJoin rightHash = new HashJoin(rightKeys, true); + Table expected = new Table.TestBuilder() + .column( 0, 1, 2, 3, 4, 5, 6, 7, 7, 8, 8, 9) // left + .column(inv, inv, 2, inv, inv, inv, inv, 0, 1, 0, 1, 3) // right + .build()) { + GatherMap[] maps = leftKeys.leftJoinGatherMaps(rightHash); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testLeftHashJoinGatherMapsNullsWithCount() { + final int inv = Integer.MIN_VALUE; + try (Table leftKeys = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) + .build(); + Table rightKeys = new Table.TestBuilder() + .column(null, null, 9, 8, 10, 32) + .build(); + HashJoin rightHash = new HashJoin(rightKeys,true); + Table expected = new Table.TestBuilder() + .column( 0, 1, 2, 3, 4, 5, 6, 7, 7, 8, 8, 9) // left + .column(inv, inv, 2, inv, inv, inv, inv, 0, 1, 0, 1, 3) // right + .build()) { + long rowCount = leftKeys.leftJoinRowCount(rightHash); + assertEquals(expected.getRowCount(), rowCount); + GatherMap[] maps = leftKeys.leftJoinGatherMaps(rightHash, rowCount); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + @Test void testConditionalLeftJoinGatherMaps() { final int inv = Integer.MIN_VALUE; @@ -1654,6 +1746,98 @@ void testInnerJoinGatherMapsNulls() { } } + @Test + void testInnerHashJoinGatherMaps() { + try (Table leftKeys = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); + Table rightKeys = new Table.TestBuilder().column(6, 5, 9, 8, 10, 32).build(); + HashJoin rightHash = new HashJoin(rightKeys, false); + Table expected = new Table.TestBuilder() + .column(2, 7, 8, 9) // left + .column(2, 0, 1, 3) // right + .build()) { + GatherMap[] maps = leftKeys.innerJoinGatherMaps(rightHash); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testInnerHashJoinGatherMapsWithCount() { + try (Table leftKeys = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); + Table rightKeys = new Table.TestBuilder().column(6, 5, 9, 8, 10, 32).build(); + HashJoin rightHash = new HashJoin(rightKeys, false); + Table expected = new Table.TestBuilder() + .column(2, 7, 8, 9) // left + .column(2, 0, 1, 3) // right + .build()) { + long rowCount = leftKeys.innerJoinRowCount(rightHash); + assertEquals(expected.getRowCount(), rowCount); + GatherMap[] maps = leftKeys.innerJoinGatherMaps(rightHash, rowCount); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testInnerHashJoinGatherMapsNulls() { + try (Table leftKeys = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) + .build(); + Table rightKeys = new Table.TestBuilder() + .column(null, null, 9, 8, 10, 32) + .build(); + HashJoin rightHash = new HashJoin(rightKeys, true); + Table expected = new Table.TestBuilder() + .column(2, 7, 7, 8, 8, 9) // left + .column(2, 0, 1, 0, 1, 3) // right + .build()) { + GatherMap[] maps = leftKeys.innerJoinGatherMaps(rightHash); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testInnerHashJoinGatherMapsNullsWithCount() { + try (Table leftKeys = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) + .build(); + Table rightKeys = new Table.TestBuilder() + .column(null, null, 9, 8, 10, 32) + .build(); + HashJoin rightHash = new HashJoin(rightKeys, true); + Table expected = new Table.TestBuilder() + .column(2, 7, 7, 8, 8, 9) // left + .column(2, 0, 1, 0, 1, 3) // right + .build()) { + long rowCount = leftKeys.innerJoinRowCount(rightHash); + assertEquals(expected.getRowCount(), rowCount); + GatherMap[] maps = leftKeys.innerJoinGatherMaps(rightHash, rowCount); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + @Test void testConditionalInnerJoinGatherMaps() { BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, @@ -1806,6 +1990,102 @@ void testFullJoinGatherMapsNulls() { } } + @Test + void testFullHashJoinGatherMaps() { + final int inv = Integer.MIN_VALUE; + try (Table leftKeys = new Table.TestBuilder().column(2, 3, 9, null, 1, 7, 4, 6, 5, 8).build(); + Table rightKeys = new Table.TestBuilder().column(6, 5, 9, 8, 10, null).build(); + HashJoin rightHash = new HashJoin(rightKeys, false); + Table expected = new Table.TestBuilder() + .column(inv, inv, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9) // left + .column( 4, 5, inv, inv, 2, inv, inv, inv, inv, 0, 1, 3) // right + .build()) { + GatherMap[] maps = leftKeys.fullJoinGatherMaps(rightHash); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testFullHashJoinGatherMapsWithCount() { + final int inv = Integer.MIN_VALUE; + try (Table leftKeys = new Table.TestBuilder().column(2, 3, 9, null, 1, 7, 4, 6, 5, 8).build(); + Table rightKeys = new Table.TestBuilder().column(6, 5, 9, 8, 10, null).build(); + HashJoin rightHash = new HashJoin(rightKeys, false); + Table expected = new Table.TestBuilder() + .column(inv, inv, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9) // left + .column( 4, 5, inv, inv, 2, inv, inv, inv, inv, 0, 1, 3) // right + .build()) { + long rowCount = leftKeys.fullJoinRowCount(rightHash); + assertEquals(expected.getRowCount(), rowCount); + GatherMap[] maps = leftKeys.fullJoinGatherMaps(rightHash, rowCount); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testFullHashJoinGatherMapsNulls() { + final int inv = Integer.MIN_VALUE; + try (Table leftKeys = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) + .build(); + Table rightKeys = new Table.TestBuilder() + .column(null, null, 9, 8, 10, 32) + .build(); + HashJoin rightHash = new HashJoin(rightKeys, true); + Table expected = new Table.TestBuilder() + .column(inv, inv, 0, 1, 2, 3, 4, 5, 6, 7, 7, 8, 8, 9) // left + .column( 4, 5, inv, inv, 2, inv, inv, inv, inv, 0, 1, 0, 1, 3) // right + .build()) { + GatherMap[] maps = leftKeys.fullJoinGatherMaps(rightHash); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testFullHashJoinGatherMapsNullsWithCount() { + final int inv = Integer.MIN_VALUE; + try (Table leftKeys = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) + .build(); + Table rightKeys = new Table.TestBuilder() + .column(null, null, 9, 8, 10, 32) + .build(); + HashJoin rightHash = new HashJoin(rightKeys, true); + Table expected = new Table.TestBuilder() + .column(inv, inv, 0, 1, 2, 3, 4, 5, 6, 7, 7, 8, 8, 9) // left + .column( 4, 5, inv, inv, 2, inv, inv, inv, inv, 0, 1, 0, 1, 3) // right + .build()) { + long rowCount = leftKeys.fullJoinRowCount(rightHash); + assertEquals(expected.getRowCount(), rowCount); + GatherMap[] maps = leftKeys.fullJoinGatherMaps(rightHash, rowCount); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + @Test void testConditionalFullJoinGatherMaps() { final int inv = Integer.MIN_VALUE;