-
Notifications
You must be signed in to change notification settings - Fork 919
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Java bindings for AST transform (#8846)
This adds Java bindings to the AST transform operation that computes a new column from an input table using an AST expression. Authors: - Jason Lowe (https://github.com/jlowe) Approvers: - Robert (Bobby) Evans (https://github.com/revans2) URL: #8846
- Loading branch information
Showing
15 changed files
with
1,794 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
/* | ||
* Copyright (c) 2021, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package ai.rapids.cudf.ast; | ||
|
||
import java.nio.ByteBuffer; | ||
|
||
/** Base class of every node in an AST */ | ||
abstract class AstNode { | ||
/** | ||
* Enumeration for the types of AST nodes that can appear in a serialized AST. | ||
* NOTE: This must be kept in sync with the `jni_serialized_node_type` in CompiledExpression.cpp! | ||
*/ | ||
protected enum NodeType { | ||
VALID_LITERAL(0), | ||
NULL_LITERAL(1), | ||
COLUMN_REFERENCE(2), | ||
UNARY_EXPRESSION(3), | ||
BINARY_EXPRESSION(4); | ||
|
||
private final byte nativeId; | ||
|
||
NodeType(int nativeId) { | ||
this.nativeId = (byte) nativeId; | ||
assert this.nativeId == nativeId; | ||
} | ||
|
||
/** Get the size in bytes to serialize this node type */ | ||
int getSerializedSize() { | ||
return Byte.BYTES; | ||
} | ||
|
||
/** Serialize this node type to the specified buffer */ | ||
void serialize(ByteBuffer bb) { | ||
bb.put(nativeId); | ||
} | ||
} | ||
|
||
/** Get the size in bytes of the serialized form of this node and all child nodes */ | ||
abstract int getSerializedSize(); | ||
|
||
/** | ||
* Serialize this node and all child nodes. | ||
* @param bb buffer to receive the serialized data | ||
*/ | ||
abstract void serialize(ByteBuffer bb); | ||
} |
48 changes: 48 additions & 0 deletions
48
java/src/main/java/ai/rapids/cudf/ast/BinaryExpression.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
/* | ||
* Copyright (c) 2021, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package ai.rapids.cudf.ast; | ||
|
||
import java.nio.ByteBuffer; | ||
|
||
/** A binary expression consisting of an operator and two operands. */ | ||
public class BinaryExpression extends Expression { | ||
private final BinaryOperator op; | ||
private final AstNode leftInput; | ||
private final AstNode rightInput; | ||
|
||
public BinaryExpression(BinaryOperator op, AstNode leftInput, AstNode rightInput) { | ||
this.op = op; | ||
this.leftInput = leftInput; | ||
this.rightInput = rightInput; | ||
} | ||
|
||
@Override | ||
int getSerializedSize() { | ||
return NodeType.BINARY_EXPRESSION.getSerializedSize() + | ||
op.getSerializedSize() + | ||
leftInput.getSerializedSize() + | ||
rightInput.getSerializedSize(); | ||
} | ||
|
||
@Override | ||
void serialize(ByteBuffer bb) { | ||
NodeType.BINARY_EXPRESSION.serialize(bb); | ||
op.serialize(bb); | ||
leftInput.serialize(bb); | ||
rightInput.serialize(bb); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
/* | ||
* Copyright (c) 2021, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package ai.rapids.cudf.ast; | ||
|
||
import java.nio.ByteBuffer; | ||
|
||
/** | ||
* Enumeration of AST operations that can appear in a binary expression. | ||
* NOTE: This must be kept in sync with `jni_to_binary_operator` in CompiledExpression.cpp! | ||
*/ | ||
public enum BinaryOperator { | ||
ADD(0), // operator + | ||
SUB(1), // operator - | ||
MUL(2), // operator * | ||
DIV(3), // operator / using common type of lhs and rhs | ||
TRUE_DIV(4), // operator / after promoting type to floating point | ||
FLOOR_DIV(5), // operator / after promoting to 64 bit floating point and then flooring the result | ||
MOD(6), // operator % | ||
PYMOD(7), // operator % but following python's sign rules for negatives | ||
POW(8), // lhs ^ rhs | ||
EQUAL(9), // operator == | ||
NOT_EQUAL(10), // operator != | ||
LESS(11), // operator < | ||
GREATER(12), // operator > | ||
LESS_EQUAL(13), // operator <= | ||
GREATER_EQUAL(14), // operator >= | ||
BITWISE_AND(15), // operator & | ||
BITWISE_OR(16), // operator | | ||
BITWISE_XOR(17), // operator ^ | ||
LOGICAL_AND(18), // operator && | ||
LOGICAL_OR(19); // operator || | ||
|
||
private final byte nativeId; | ||
|
||
BinaryOperator(int nativeId) { | ||
this.nativeId = (byte) nativeId; | ||
assert this.nativeId == nativeId; | ||
} | ||
|
||
/** Get the size in bytes to serialize this operator */ | ||
int getSerializedSize() { | ||
return Byte.BYTES; | ||
} | ||
|
||
/** Serialize this operator to the specified buffer */ | ||
void serialize(ByteBuffer bb) { | ||
bb.put(nativeId); | ||
} | ||
} |
51 changes: 51 additions & 0 deletions
51
java/src/main/java/ai/rapids/cudf/ast/ColumnReference.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
/* | ||
* Copyright (c) 2021, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package ai.rapids.cudf.ast; | ||
|
||
import java.nio.ByteBuffer; | ||
|
||
/** A reference to a column in an input table. */ | ||
public final class ColumnReference extends AstNode { | ||
private final int columnIndex; | ||
private final TableReference tableSource; | ||
|
||
/** Construct a column reference to either the only or leftmost input table */ | ||
public ColumnReference(int columnIndex) { | ||
this(columnIndex, TableReference.LEFT); | ||
} | ||
|
||
/** Construct a column reference to the specified column index in the specified table */ | ||
public ColumnReference(int columnIndex, TableReference tableSource) { | ||
this.columnIndex = columnIndex; | ||
this.tableSource = tableSource; | ||
} | ||
|
||
@Override | ||
int getSerializedSize() { | ||
// node type + table ref + column index | ||
return NodeType.COLUMN_REFERENCE.getSerializedSize() + | ||
tableSource.getSerializedSize() + | ||
Integer.BYTES; | ||
} | ||
|
||
@Override | ||
void serialize(ByteBuffer bb) { | ||
NodeType.COLUMN_REFERENCE.serialize(bb); | ||
tableSource.serialize(bb); | ||
bb.putInt(columnIndex); | ||
} | ||
} |
100 changes: 100 additions & 0 deletions
100
java/src/main/java/ai/rapids/cudf/ast/CompiledExpression.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
/* | ||
* Copyright (c) 2021, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package ai.rapids.cudf.ast; | ||
|
||
import ai.rapids.cudf.ColumnVector; | ||
import ai.rapids.cudf.MemoryCleaner; | ||
import ai.rapids.cudf.Table; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
/** This class wraps a native compiled AST and must be closed to avoid native memory leaks. */ | ||
public class CompiledExpression implements AutoCloseable { | ||
private static final Logger log = LoggerFactory.getLogger(CompiledExpression.class); | ||
|
||
private static class CompiledExpressionCleaner extends MemoryCleaner.Cleaner { | ||
private long nativeHandle; | ||
|
||
CompiledExpressionCleaner(long nativeHandle) { | ||
this.nativeHandle = nativeHandle; | ||
} | ||
|
||
@Override | ||
protected synchronized boolean cleanImpl(boolean logErrorIfNotClean) { | ||
long origAddress = nativeHandle; | ||
boolean neededCleanup = nativeHandle != 0; | ||
if (neededCleanup) { | ||
try { | ||
destroy(nativeHandle); | ||
} finally { | ||
nativeHandle = 0; | ||
} | ||
if (logErrorIfNotClean) { | ||
log.error("AN AST COMPILED EXPRESSION WAS LEAKED (ID: " + | ||
id + " " + Long.toHexString(origAddress)); | ||
} | ||
} | ||
return neededCleanup; | ||
} | ||
|
||
@Override | ||
public boolean isClean() { | ||
return nativeHandle == 0; | ||
} | ||
} | ||
|
||
private final CompiledExpressionCleaner cleaner; | ||
private boolean isClosed = false; | ||
|
||
/** Construct a compiled expression from a serialized AST */ | ||
CompiledExpression(byte[] serializedExpression) { | ||
this(compile(serializedExpression)); | ||
} | ||
|
||
/** Construct a compiled expression from a native compiled AST pointer */ | ||
CompiledExpression(long nativeHandle) { | ||
this.cleaner = new CompiledExpressionCleaner(nativeHandle); | ||
MemoryCleaner.register(this, cleaner); | ||
cleaner.addRef(); | ||
} | ||
|
||
/** | ||
* Compute a new column by applying this AST expression to the specified table. All | ||
* {@link ColumnReference} instances within the expression will use the sole input table, | ||
* even if they try to specify a non-existent table, e.g.: {@link TableReference#RIGHT}. | ||
* @param table input table for this expression | ||
* @return new column computed from this expression applied to the input table | ||
*/ | ||
public ColumnVector computeColumn(Table table) { | ||
return new ColumnVector(computeColumn(cleaner.nativeHandle, table.getNativeView())); | ||
} | ||
|
||
@Override | ||
public synchronized void close() { | ||
cleaner.delRef(); | ||
if (isClosed) { | ||
cleaner.logRefCountDebug("double free " + this); | ||
throw new IllegalStateException("Close called too many times " + this); | ||
} | ||
cleaner.clean(false); | ||
isClosed = true; | ||
} | ||
|
||
private static native long compile(byte[] serializedExpression); | ||
private static native long computeColumn(long astHandle, long tableHandle); | ||
private static native void destroy(long handle); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
/* | ||
* Copyright (c) 2021, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package ai.rapids.cudf.ast; | ||
|
||
import java.nio.ByteBuffer; | ||
import java.nio.ByteOrder; | ||
|
||
/** Base class of every AST expression. */ | ||
public abstract class Expression extends AstNode { | ||
public CompiledExpression compile() { | ||
int size = getSerializedSize(); | ||
ByteBuffer bb = ByteBuffer.allocate(size); | ||
bb.order(ByteOrder.nativeOrder()); | ||
serialize(bb); | ||
return new CompiledExpression(bb.array()); | ||
} | ||
} |
Oops, something went wrong.