Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Java bindings for left outer distinct join #15154

Merged
merged 3 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 41 additions & 11 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,9 @@ private static native long[] merge(long[] tableHandles, int[] sortKeyIndexes,
private static native long[] leftJoinGatherMaps(long leftKeys, long rightKeys,
boolean compareNullsEqual) throws CudfException;

/**
 * Native entry point for the left distinct join; implemented in TableJni.cpp by
 * Java_ai_rapids_cudf_Table_leftDistinctJoinGatherMap.
 *
 * @param leftKeys native view handle of the left key table
 * @param rightKeys native view handle of the right key table
 * @param compareNullsEqual true if null key values should match otherwise false
 * @return raw gather map description; the Java wrapper hands this array to
 *         buildSingleJoinGatherMap, which reads element 0 as the buffer size, element 1 as
 *         the buffer address and element 2 as the buffer handle
 */
private static native long[] leftDistinctJoinGatherMap(long leftKeys, long rightKeys,
boolean compareNullsEqual) throws CudfException;

private static native long leftJoinRowCount(long leftTable, long rightHashJoin) throws CudfException;

private static native long[] leftHashJoinGatherMaps(long leftTable, long rightHashJoin) throws CudfException;
Expand Down Expand Up @@ -2949,6 +2952,33 @@ public GatherMap[] leftJoinGatherMaps(Table rightKeys, boolean compareNullsEqual
return buildJoinGatherMaps(gatherMapData);
}

/**
 * Computes a gather map that can be used to manifest the result of a left equi-join between
 * two tables where the right table is guaranteed to not contain any duplicated join keys.
 * The left table can be used as-is to produce the left table columns resulting from the join,
 * i.e.: left table ordering is preserved in the join result, so no gather map is required for
 * the left table. The resulting gather map can be applied to the right table to produce the
 * right table columns resulting from the join. It is assumed this table instance holds the
 * key columns from the left table, and the table argument represents the key columns from the
 * right table. A {@link GatherMap} instance will be returned that can be used to gather the
 * right table and that result combined with the left table to produce a left outer join result.
 *
 * The gather map holds one entry per left row. Left rows without a matching right key receive
 * an out-of-bounds index (the unit tests expect {@code Integer.MIN_VALUE} for such rows).
 *
 * It is the responsibility of the caller to close the resulting gather map instance.
 *
 * @param rightKeys join key columns from the right table
 * @param compareNullsEqual true if null key values should match otherwise false
 * @return right table gather map
 * @throws IllegalArgumentException if this table and {@code rightKeys} do not have the same
 *         number of columns
 */
public GatherMap leftDistinctJoinGatherMap(Table rightKeys, boolean compareNullsEqual) {
  if (getNumberOfColumns() != rightKeys.getNumberOfColumns()) {
    // Keep a separator before "rightKeys" so the two counts do not run together in the message.
    throw new IllegalArgumentException("Column count mismatch, this: " + getNumberOfColumns() +
        ", rightKeys: " + rightKeys.getNumberOfColumns());
  }
  long[] gatherMapData =
      leftDistinctJoinGatherMap(getNativeView(), rightKeys.getNativeView(), compareNullsEqual);
  // The native call describes a single gather map (size, address, handle).
  return buildSingleJoinGatherMap(gatherMapData);
}

/**
* Computes the number of rows resulting from a left equi-join between two tables.
* It is assumed this table instance holds the key columns from the left table, and the
Expand Down Expand Up @@ -3576,7 +3606,7 @@ public static GatherMap[] mixedFullJoinGatherMaps(Table leftKeys, Table rightKey
return buildJoinGatherMaps(gatherMapData);
}

private static GatherMap buildSemiJoinGatherMap(long[] gatherMapData) {
private static GatherMap buildSingleJoinGatherMap(long[] gatherMapData) {
long bufferSize = gatherMapData[0];
long leftAddr = gatherMapData[1];
long leftHandle = gatherMapData[2];
Expand All @@ -3601,7 +3631,7 @@ public GatherMap leftSemiJoinGatherMap(Table rightKeys, boolean compareNullsEqua
}
long[] gatherMapData =
leftSemiJoinGatherMap(getNativeView(), rightKeys.getNativeView(), compareNullsEqual);
return buildSemiJoinGatherMap(gatherMapData);
return buildSingleJoinGatherMap(gatherMapData);
}

/**
Expand Down Expand Up @@ -3634,7 +3664,7 @@ public GatherMap conditionalLeftSemiJoinGatherMap(Table rightTable,
long[] gatherMapData =
conditionalLeftSemiJoinGatherMap(getNativeView(), rightTable.getNativeView(),
condition.getNativeHandle());
return buildSemiJoinGatherMap(gatherMapData);
return buildSingleJoinGatherMap(gatherMapData);
}

/**
Expand All @@ -3659,7 +3689,7 @@ public GatherMap conditionalLeftSemiJoinGatherMap(Table rightTable,
long[] gatherMapData =
conditionalLeftSemiJoinGatherMapWithCount(getNativeView(), rightTable.getNativeView(),
condition.getNativeHandle(), outputRowCount);
return buildSemiJoinGatherMap(gatherMapData);
return buildSingleJoinGatherMap(gatherMapData);
}

/**
Expand Down Expand Up @@ -3716,7 +3746,7 @@ public static GatherMap mixedLeftSemiJoinGatherMap(Table leftKeys, Table rightKe
leftConditional.getNativeView(), rightConditional.getNativeView(),
condition.getNativeHandle(),
nullEquality == NullEquality.EQUAL);
return buildSemiJoinGatherMap(gatherMapData);
return buildSingleJoinGatherMap(gatherMapData);
}

/**
Expand Down Expand Up @@ -3752,7 +3782,7 @@ public static GatherMap mixedLeftSemiJoinGatherMap(Table leftKeys, Table rightKe
condition.getNativeHandle(),
nullEquality == NullEquality.EQUAL,
joinSize.getOutputRowCount(), joinSize.getMatches().getNativeView());
return buildSemiJoinGatherMap(gatherMapData);
return buildSingleJoinGatherMap(gatherMapData);
}

/**
Expand All @@ -3773,7 +3803,7 @@ public GatherMap leftAntiJoinGatherMap(Table rightKeys, boolean compareNullsEqua
}
long[] gatherMapData =
leftAntiJoinGatherMap(getNativeView(), rightKeys.getNativeView(), compareNullsEqual);
return buildSemiJoinGatherMap(gatherMapData);
return buildSingleJoinGatherMap(gatherMapData);
}

/**
Expand Down Expand Up @@ -3806,7 +3836,7 @@ public GatherMap conditionalLeftAntiJoinGatherMap(Table rightTable,
long[] gatherMapData =
conditionalLeftAntiJoinGatherMap(getNativeView(), rightTable.getNativeView(),
condition.getNativeHandle());
return buildSemiJoinGatherMap(gatherMapData);
return buildSingleJoinGatherMap(gatherMapData);
}

/**
Expand All @@ -3831,7 +3861,7 @@ public GatherMap conditionalLeftAntiJoinGatherMap(Table rightTable,
long[] gatherMapData =
conditionalLeftAntiJoinGatherMapWithCount(getNativeView(), rightTable.getNativeView(),
condition.getNativeHandle(), outputRowCount);
return buildSemiJoinGatherMap(gatherMapData);
return buildSingleJoinGatherMap(gatherMapData);
}

/**
Expand Down Expand Up @@ -3888,7 +3918,7 @@ public static GatherMap mixedLeftAntiJoinGatherMap(Table leftKeys, Table rightKe
leftConditional.getNativeView(), rightConditional.getNativeView(),
condition.getNativeHandle(),
nullEquality == NullEquality.EQUAL);
return buildSemiJoinGatherMap(gatherMapData);
return buildSingleJoinGatherMap(gatherMapData);
}

/**
Expand Down Expand Up @@ -3924,7 +3954,7 @@ public static GatherMap mixedLeftAntiJoinGatherMap(Table leftKeys, Table rightKe
condition.getNativeHandle(),
nullEquality == NullEquality.EQUAL,
joinSize.getOutputRowCount(), joinSize.getMatches().getNativeView());
return buildSemiJoinGatherMap(gatherMapData);
return buildSingleJoinGatherMap(gatherMapData);
}

/**
Expand Down
18 changes: 18 additions & 0 deletions java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2434,6 +2434,24 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftJoinGatherMaps(
});
}

// JNI entry point for Table.leftDistinctJoinGatherMap. Delegates to join_gather_single_map,
// supplying a callback that runs a distinct hash join constructed from the right keys and
// the left keys (in that argument order) and returns the result of its left_join().
JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftDistinctJoinGatherMap(
    JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) {
  return cudf::jni::join_gather_single_map(
      env, j_left_keys, j_right_keys, compare_nulls_equal,
      [](cudf::table_view const &left_keys, cudf::table_view const &right_keys,
         cudf::null_equality null_eq) {
        // The join is flagged nullable when either side reports nested nulls.
        bool const any_nulls =
            cudf::has_nested_nulls(left_keys) || cudf::has_nested_nulls(right_keys);
        auto nullability = any_nulls ? cudf::nullable_join::YES : cudf::nullable_join::NO;

        // The template argument is chosen from the right table only.
        if (!cudf::detail::has_nested_columns(right_keys)) {
          cudf::distinct_hash_join<cudf::has_nested::NO> joiner(right_keys, left_keys,
                                                                nullability, null_eq);
          return joiner.left_join();
        }
        cudf::distinct_hash_join<cudf::has_nested::YES> joiner(right_keys, left_keys,
                                                               nullability, null_eq);
        return joiner.left_join();
      });
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_leftJoinRowCount(JNIEnv *env, jclass,
jlong j_left_table,
jlong j_right_hash_join) {
Expand Down
101 changes: 101 additions & 0 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1734,6 +1734,107 @@ void testLeftJoinGatherMapsNulls() {
}
}

/**
 * Runs a left distinct join of {@code leftKeys} against {@code rightKeys} and verifies that
 * the returned gather map has the same row count and contents as {@code expected}.
 *
 * @param leftKeys key columns of the left table
 * @param rightKeys key columns of the right table
 * @param expected expected gather map contents, one entry per left row
 * @param compareNullsEqual whether null keys are treated as equal to each other
 */
private void checkLeftDistinctJoin(Table leftKeys, Table rightKeys, ColumnView expected,
                                   boolean compareNullsEqual) {
  try (GatherMap gatherMap = leftKeys.leftDistinctJoinGatherMap(rightKeys, compareNullsEqual)) {
    int expectedRows = (int) expected.getRowCount();
    assertEquals(expectedRows, gatherMap.getRowCount());
    // View the whole map as a column so it can be compared directly with the expectation.
    try (ColumnView actual = gatherMap.toColumnView(0, expectedRows)) {
      assertColumnsAreEqual(expected, actual);
    }
  }
}

/**
 * Left distinct join on a single integer key column without nulls.
 */
@Test
void testLeftDistinctJoinGatherMaps() {
  // Gather map entry expected for left rows whose key does not appear in the right table.
  final int noMatch = Integer.MIN_VALUE;
  try (Table left = new Table.TestBuilder()
           .column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8, 6)
           .build();
       Table right = new Table.TestBuilder()
           .column(6, 5, 9, 8, 10, 32)
           .build();
       // Left keys 9, 6, 5, 8 and 6 map to right rows 2, 0, 1, 3 and 0 respectively.
       ColumnVector expectedMap = ColumnVector.fromInts(
           noMatch, noMatch, 2, noMatch, noMatch, noMatch, noMatch, 0, 1, 3, 0)) {
    checkLeftDistinctJoin(left, right, expectedMap, false);
  }
}

/**
 * Left distinct join on a single non-nullable struct(string, int) key column.
 */
@Test
void testLeftDistinctJoinGatherMapsWithNested() {
  // Gather map entry expected for left rows whose key does not appear in the right table.
  final int noMatch = Integer.MIN_VALUE;
  BasicType stringField = new BasicType(false, DType.STRING);
  BasicType intField = new BasicType(false, DType.INT32);
  StructType keyType = new StructType(false, stringField, intField);
  StructData[] leftRows = {
      new StructData("abc", 1),   // matches right row 0
      new StructData("xyz", 1),
      new StructData("abc", 2),
      new StructData("xyz", 2),   // matches right row 2
      new StructData("abc", 1),   // matches right row 0
      new StructData("abc", 3),
      new StructData("xyz", 3)
  };
  StructData[] rightRows = {
      new StructData("abc", 1),
      new StructData("xyz", 4),
      new StructData("xyz", 2),
      new StructData("abc", -1),
  };
  try (Table left = new Table.TestBuilder().column(keyType, leftRows).build();
       Table right = new Table.TestBuilder().column(keyType, rightRows).build();
       ColumnVector expectedMap = ColumnVector.fromInts(
           0, noMatch, noMatch, 2, 0, noMatch, noMatch)) {
    checkLeftDistinctJoin(left, right, expectedMap, false);
  }
}

/**
 * Left distinct join on a nullable integer key column with nulls compared as equal.
 */
@Test
void testLeftDistinctJoinGatherMapsNullsEqual() {
  // Gather map entry expected for left rows whose key does not appear in the right table.
  final int noMatch = Integer.MIN_VALUE;
  try (Table left = new Table.TestBuilder()
           .column(2, 3, 9, 0, 1, 7, 4, null, null, 8)
           .build();
       Table right = new Table.TestBuilder()
           .column(null, 9, 8, 10, 32)
           .build();
       // Both null left keys are expected to map to the null key at right row 0.
       ColumnVector expectedMap = ColumnVector.fromInts(
           noMatch, noMatch, 1, noMatch, noMatch, noMatch, noMatch, 0, 0, 2)) {
    checkLeftDistinctJoin(left, right, expectedMap, true);
  }
}

/**
 * Left distinct join on a nullable struct(string, int) key column with nulls compared as
 * equal, covering null structs as well as structs with null fields.
 */
@Test
void testLeftDistinctJoinGatherMapsWithNestedNullsEqual() {
  // Gather map entry expected for left rows whose key does not appear in the right table.
  final int noMatch = Integer.MIN_VALUE;
  BasicType stringField = new BasicType(true, DType.STRING);
  BasicType intField = new BasicType(true, DType.INT32);
  StructType keyType = new StructType(true, stringField, intField);
  StructData[] leftRows = {
      new StructData("abc", 1),      // matches right row 1
      null,                          // matches right row 0
      new StructData("xyz", 1),
      new StructData("abc", 2),
      new StructData("xyz", null),   // matches right row 7
      null,                          // matches right row 0
      new StructData("abc", 1),      // matches right row 1
      new StructData("abc", 3),
      new StructData("xyz", 3),
      new StructData(null, null),    // matches right row 4
      new StructData(null, 1)        // matches right row 6
  };
  StructData[] rightRows = {
      null,
      new StructData("abc", 1),
      new StructData("xyz", 4),
      new StructData("xyz", 2),
      new StructData(null, null),
      new StructData(null, 2),
      new StructData(null, 1),
      new StructData("xyz", null),
      new StructData("abc", null),
      new StructData("abc", -1)
  };
  try (Table left = new Table.TestBuilder().column(keyType, leftRows).build();
       Table right = new Table.TestBuilder().column(keyType, rightRows).build();
       ColumnVector expectedMap = ColumnVector.fromInts(
           1, 0, noMatch, noMatch, 7, 0, 1, noMatch, noMatch, 4, 6)) {
    checkLeftDistinctJoin(left, right, expectedMap, true);
  }
}

@Test
void testLeftHashJoinGatherMaps() {
final int inv = Integer.MIN_VALUE;
Expand Down
Loading