Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Java bindings for conditional join gather maps #8888

Merged
merged 2 commits into from
Jul 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import ai.rapids.cudf.HostColumnVector.ListType;
import ai.rapids.cudf.HostColumnVector.StructData;
import ai.rapids.cudf.HostColumnVector.StructType;
import ai.rapids.cudf.ast.CompiledExpression;

import java.io.File;
import java.math.BigDecimal;
Expand Down Expand Up @@ -523,6 +524,26 @@ private static native long[] leftAntiJoin(long leftTable, int[] leftJoinCols, lo
private static native long[] leftAntiJoinGatherMap(long leftKeys, long rightKeys,
boolean compareNullsEqual) throws CudfException;

/**
 * Native entry point for the conditional left join. Invoked by
 * {@code leftJoinGatherMaps(Table, CompiledExpression, boolean)} with the native views of the
 * left and right tables and the native handle of the compiled condition. The returned array
 * is handed to {@code buildJoinGatherMaps}.
 */
private static native long[] conditionalLeftJoinGatherMaps(long leftTable, long rightTable,
long condition,
boolean compareNullsEqual) throws CudfException;

/**
 * Native entry point for the conditional inner join. Invoked by
 * {@code innerJoinGatherMaps(Table, CompiledExpression, boolean)} with the native views of the
 * left and right tables and the native handle of the compiled condition. The returned array
 * is handed to {@code buildJoinGatherMaps}.
 */
private static native long[] conditionalInnerJoinGatherMaps(long leftTable, long rightTable,
long condition,
boolean compareNullsEqual) throws CudfException;

/**
 * Native entry point for the conditional full join. Invoked by
 * {@code fullJoinGatherMaps(Table, CompiledExpression, boolean)} with the native views of the
 * left and right tables and the native handle of the compiled condition. The returned array
 * is handed to {@code buildJoinGatherMaps}.
 */
private static native long[] conditionalFullJoinGatherMaps(long leftTable, long rightTable,
long condition,
boolean compareNullsEqual) throws CudfException;

/**
 * Native entry point for the conditional left semi join. Invoked by
 * {@code leftSemiJoinGatherMap(Table, CompiledExpression, boolean)} with the native views of
 * the left and right tables and the native handle of the compiled condition. The returned
 * array is handed to {@code buildSemiJoinGatherMap}.
 */
private static native long[] conditionalLeftSemiJoinGatherMap(long leftTable, long rightTable,
long condition,
boolean compareNullsEqual) throws CudfException;

/**
 * Native entry point for the conditional left anti join. Invoked by
 * {@code leftAntiJoinGatherMap(Table, CompiledExpression, boolean)} with the native views of
 * the left and right tables and the native handle of the compiled condition. The returned
 * array is handed to {@code buildSemiJoinGatherMap}.
 */
private static native long[] conditionalLeftAntiJoinGatherMap(long leftTable, long rightTable,
long condition,
boolean compareNullsEqual) throws CudfException;

private static native long[] crossJoin(long leftTable, long rightTable) throws CudfException;

private static native long[] concatenate(long[] cudfTablePointers) throws CudfException;
Expand Down Expand Up @@ -1969,6 +1990,30 @@ public GatherMap[] leftJoinGatherMaps(Table rightKeys, boolean compareNullsEqual
return buildJoinGatherMaps(gatherMapData);
}

/**
 * Computes the gather maps that can be used to manifest the result of a left join between
 * two tables when a conditional expression is true. It is assumed this table instance holds
 * the columns from the left table, and the table argument represents the columns from the
 * right table. Two {@link GatherMap} instances will be returned that can be used to gather
 * the left and right tables, respectively, to produce the result of the left join.
 * It is the responsibility of the caller to close the resulting gather map instances.
 * <p>
 * Unlike the equi-join overload that takes key tables, the two tables passed here are not
 * required to have the same number of columns.
 * @param rightTable the right side table of the join
 * @param condition conditional expression to evaluate during the join
 * @param compareNullsEqual true if null values should be treated as equal when compared,
 *                          otherwise false
 * @return left and right table gather maps
 */
public GatherMap[] leftJoinGatherMaps(Table rightTable, CompiledExpression condition,
                                      boolean compareNullsEqual) {
  // No column-count validation: the condition addresses columns of each table individually,
  // so a left/right column-count mismatch is a legal input for a conditional join.
  long[] gatherMapData =
      conditionalLeftJoinGatherMaps(getNativeView(), rightTable.getNativeView(),
          condition.getNativeHandle(), compareNullsEqual);
  return buildJoinGatherMaps(gatherMapData);
}

/**
* Computes the gather maps that can be used to manifest the result of an inner equi-join between
* two tables. It is assumed this table instance holds the key columns from the left table, and
Expand All @@ -1990,6 +2035,30 @@ public GatherMap[] innerJoinGatherMaps(Table rightKeys, boolean compareNullsEqua
return buildJoinGatherMaps(gatherMapData);
}

/**
 * Computes the gather maps that can be used to manifest the result of an inner join between
 * two tables when a conditional expression is true. It is assumed this table instance holds
 * the columns from the left table, and the table argument represents the columns from the
 * right table. Two {@link GatherMap} instances will be returned that can be used to gather
 * the left and right tables, respectively, to produce the result of the inner join.
 * It is the responsibility of the caller to close the resulting gather map instances.
 * <p>
 * Unlike the equi-join overload that takes key tables, the two tables passed here are not
 * required to have the same number of columns.
 * @param rightTable the right side table of the join
 * @param condition conditional expression to evaluate during the join
 * @param compareNullsEqual true if null values should be treated as equal when compared,
 *                          otherwise false
 * @return left and right table gather maps
 */
public GatherMap[] innerJoinGatherMaps(Table rightTable, CompiledExpression condition,
                                       boolean compareNullsEqual) {
  // No column-count validation: the condition addresses columns of each table individually,
  // so a left/right column-count mismatch is a legal input for a conditional join.
  long[] gatherMapData =
      conditionalInnerJoinGatherMaps(getNativeView(), rightTable.getNativeView(),
          condition.getNativeHandle(), compareNullsEqual);
  return buildJoinGatherMaps(gatherMapData);
}

/**
* Computes the gather maps that can be used to manifest the result of a full equi-join between
* two tables. It is assumed this table instance holds the key columns from the left table, and
Expand All @@ -2011,6 +2080,30 @@ public GatherMap[] fullJoinGatherMaps(Table rightKeys, boolean compareNullsEqual
return buildJoinGatherMaps(gatherMapData);
}

/**
 * Computes the gather maps that can be used to manifest the result of a full join between
 * two tables when a conditional expression is true. It is assumed this table instance holds
 * the columns from the left table, and the table argument represents the columns from the
 * right table. Two {@link GatherMap} instances will be returned that can be used to gather
 * the left and right tables, respectively, to produce the result of the full join.
 * It is the responsibility of the caller to close the resulting gather map instances.
 * <p>
 * Unlike the equi-join overload that takes key tables, the two tables passed here are not
 * required to have the same number of columns.
 * @param rightTable the right side table of the join
 * @param condition conditional expression to evaluate during the join
 * @param compareNullsEqual true if null values should be treated as equal when compared,
 *                          otherwise false
 * @return left and right table gather maps
 */
public GatherMap[] fullJoinGatherMaps(Table rightTable, CompiledExpression condition,
                                      boolean compareNullsEqual) {
  // No column-count validation: the condition addresses columns of each table individually,
  // so a left/right column-count mismatch is a legal input for a conditional join.
  long[] gatherMapData =
      conditionalFullJoinGatherMaps(getNativeView(), rightTable.getNativeView(),
          condition.getNativeHandle(), compareNullsEqual);
  return buildJoinGatherMaps(gatherMapData);
}

private GatherMap buildSemiJoinGatherMap(long[] gatherMapData) {
long bufferSize = gatherMapData[0];
long leftAddr = gatherMapData[1];
Expand Down Expand Up @@ -2039,6 +2132,30 @@ public GatherMap leftSemiJoinGatherMap(Table rightKeys, boolean compareNullsEqua
return buildSemiJoinGatherMap(gatherMapData);
}

/**
 * Computes the gather map that can be used to manifest the result of a left semi join between
 * two tables when a conditional expression is true. It is assumed this table instance holds
 * the columns from the left table, and the table argument represents the columns from the
 * right table. The {@link GatherMap} instance returned can be used to gather the left table
 * to produce the result of the left semi join.
 * It is the responsibility of the caller to close the resulting gather map instance.
 * <p>
 * Unlike the equi-join overload that takes key tables, the two tables passed here are not
 * required to have the same number of columns.
 * @param rightTable the right side table of the join
 * @param condition conditional expression to evaluate during the join
 * @param compareNullsEqual true if null values should be treated as equal when compared,
 *                          otherwise false
 * @return left table gather map
 */
public GatherMap leftSemiJoinGatherMap(Table rightTable, CompiledExpression condition,
                                       boolean compareNullsEqual) {
  // No column-count validation: the condition addresses columns of each table individually,
  // so a left/right column-count mismatch is a legal input for a conditional join.
  long[] gatherMapData =
      conditionalLeftSemiJoinGatherMap(getNativeView(), rightTable.getNativeView(),
          condition.getNativeHandle(), compareNullsEqual);
  return buildSemiJoinGatherMap(gatherMapData);
}

/**
* Computes the gather map that can be used to manifest the result of a left anti-join between
* two tables. It is assumed this table instance holds the key columns from the left table, and
Expand All @@ -2060,6 +2177,30 @@ public GatherMap leftAntiJoinGatherMap(Table rightKeys, boolean compareNullsEqua
return buildSemiJoinGatherMap(gatherMapData);
}

/**
 * Computes the gather map that can be used to manifest the result of a left anti join between
 * two tables when a conditional expression is true. It is assumed this table instance holds
 * the columns from the left table, and the table argument represents the columns from the
 * right table. The {@link GatherMap} instance returned can be used to gather the left table
 * to produce the result of the left anti join.
 * It is the responsibility of the caller to close the resulting gather map instance.
 * <p>
 * Unlike the equi-join overload that takes key tables, the two tables passed here are not
 * required to have the same number of columns.
 * @param rightTable the right side table of the join
 * @param condition conditional expression to evaluate during the join
 * @param compareNullsEqual true if null values should be treated as equal when compared,
 *                          otherwise false
 * @return left table gather map
 */
public GatherMap leftAntiJoinGatherMap(Table rightTable, CompiledExpression condition,
                                       boolean compareNullsEqual) {
  // No column-count validation: the condition addresses columns of each table individually,
  // so a left/right column-count mismatch is a legal input for a conditional join.
  long[] gatherMapData =
      conditionalLeftAntiJoinGatherMap(getNativeView(), rightTable.getNativeView(),
          condition.getNativeHandle(), compareNullsEqual);
  return buildSemiJoinGatherMap(gatherMapData);
}

/**
* Convert this table of columns into a row major format that is useful for interacting with other
* systems that do row major processing of the data. Currently only fixed-width column types are
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ public synchronized void close() {
isClosed = true;
}

/**
 * Exposes the native address of this compiled expression, as recorded in its cleaner.
 * Intended for internal cudf use only.
 * @return the cleaner's current {@code nativeHandle} value
 */
public long getNativeHandle() {
  final long handle = cleaner.nativeHandle;
  return handle;
}

private static native long compile(byte[] serializedExpression);
private static native long computeColumn(long astHandle, long tableHandle);
private static native void destroy(long handle);
Expand Down
54 changes: 1 addition & 53 deletions java/src/main/native/src/CompiledExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,59 +26,7 @@
#include <cudf/types.hpp>

#include "cudf_jni_apis.hpp"

namespace cudf {
namespace jni {
namespace ast {

/**
 * Owns every resource that belongs to one compiled AST expression.
 *
 * An AST node does not own its children, so each node of the expression tree has to be
 * kept alive (and eventually released) by this container rather than by its parent node.
 *
 * This should be cleaned up a bit after the libcudf AST refactoring in
 * https://github.com/rapidsai/cudf/pull/8815 when a virtual destructor is added to the
 * base AST node type. Then we do not have to track every AST node type separately.
 */
class compiled_expr {
public:
  /**
   * Takes ownership of a literal node together with the scalar that backs it.
   * @return reference to the literal node that was just stored
   */
  cudf::ast::literal &add_literal(std::unique_ptr<cudf::ast::literal> literal_ptr,
                                  std::unique_ptr<cudf::scalar> scalar_ptr) {
    literal_nodes.emplace_back(std::move(literal_ptr));
    literal_scalars.emplace_back(std::move(scalar_ptr));
    auto const &stored = literal_nodes.back();
    return *stored;
  }

  /**
   * Takes ownership of a column reference node.
   * @return reference to the column reference node that was just stored
   */
  cudf::ast::column_reference &
  add_column_ref(std::unique_ptr<cudf::ast::column_reference> ref_ptr) {
    column_ref_nodes.emplace_back(std::move(ref_ptr));
    auto const &stored = column_ref_nodes.back();
    return *stored;
  }

  /**
   * Takes ownership of an expression node.
   * @return reference to the expression node that was just stored
   */
  cudf::ast::expression &add_expression(std::unique_ptr<cudf::ast::expression> expr_ptr) {
    expression_nodes.emplace_back(std::move(expr_ptr));
    auto const &stored = expression_nodes.back();
    return *stored;
  }

  /** Return the expression node at the top of the tree (the most recently added one) */
  cudf::ast::expression &get_top_expression() const {
    auto const &top = expression_nodes.back();
    return *top;
  }

private:
  // NOTE: the member order below matches the original declaration order so that the
  // destruction order of the owned resources is unchanged.

  /** Every literal node in the expression tree */
  std::vector<std::unique_ptr<cudf::ast::literal>> literal_nodes;

  /** Every column reference node in the expression tree */
  std::vector<std::unique_ptr<cudf::ast::column_reference>> column_ref_nodes;

  /** Every expression node in the expression tree */
  std::vector<std::unique_ptr<cudf::ast::expression>> expression_nodes;

  /** GPU scalar instances that correspond to the literal nodes */
  std::vector<std::unique_ptr<cudf::scalar>> literal_scalars;
};

} // namespace ast
} // namespace jni
} // namespace cudf
#include "jni_compiled_expr.hpp"

namespace {

Expand Down
Loading