From ac8cc53abef166187cfc9b8500b09612f7abc8e1 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Wed, 28 Jul 2021 14:10:15 -0500 Subject: [PATCH] Add Java bindings for AST transform (#8846) This adds Java bindings to the AST transform operation that computes a new column from an input table using an AST expression. Authors: - Jason Lowe (https://github.com/jlowe) Approvers: - Robert (Bobby) Evans (https://github.com/revans2) URL: https://github.com/rapidsai/cudf/pull/8846 --- .../java/ai/rapids/cudf/MemoryCleaner.java | 7 +- java/src/main/java/ai/rapids/cudf/Table.java | 2 +- .../main/java/ai/rapids/cudf/ast/AstNode.java | 60 ++ .../ai/rapids/cudf/ast/BinaryExpression.java | 48 ++ .../ai/rapids/cudf/ast/BinaryOperator.java | 63 ++ .../ai/rapids/cudf/ast/ColumnReference.java | 51 ++ .../rapids/cudf/ast/CompiledExpression.java | 100 +++ .../java/ai/rapids/cudf/ast/Expression.java | 31 + .../main/java/ai/rapids/cudf/ast/Literal.java | 264 ++++++++ .../ai/rapids/cudf/ast/TableReference.java | 47 ++ .../ai/rapids/cudf/ast/UnaryExpression.java | 44 ++ .../ai/rapids/cudf/ast/UnaryOperator.java | 65 ++ java/src/main/native/CMakeLists.txt | 1 + .../main/native/src/CompiledExpression.cpp | 437 +++++++++++++ .../cudf/ast/CompiledExpressionTest.java | 576 ++++++++++++++++++ 15 files changed, 1794 insertions(+), 2 deletions(-) create mode 100644 java/src/main/java/ai/rapids/cudf/ast/AstNode.java create mode 100644 java/src/main/java/ai/rapids/cudf/ast/BinaryExpression.java create mode 100644 java/src/main/java/ai/rapids/cudf/ast/BinaryOperator.java create mode 100644 java/src/main/java/ai/rapids/cudf/ast/ColumnReference.java create mode 100644 java/src/main/java/ai/rapids/cudf/ast/CompiledExpression.java create mode 100644 java/src/main/java/ai/rapids/cudf/ast/Expression.java create mode 100644 java/src/main/java/ai/rapids/cudf/ast/Literal.java create mode 100644 java/src/main/java/ai/rapids/cudf/ast/TableReference.java create mode 100644 java/src/main/java/ai/rapids/cudf/ast/UnaryExpression.java create mode 100644 java/src/main/java/ai/rapids/cudf/ast/UnaryOperator.java create mode 100644 java/src/main/native/src/CompiledExpression.cpp create mode 100644 java/src/test/java/ai/rapids/cudf/ast/CompiledExpressionTest.java diff --git a/java/src/main/java/ai/rapids/cudf/MemoryCleaner.java b/java/src/main/java/ai/rapids/cudf/MemoryCleaner.java index aa084ad7eef..4bf38543a2d 100644 --- a/java/src/main/java/ai/rapids/cudf/MemoryCleaner.java +++ b/java/src/main/java/ai/rapids/cudf/MemoryCleaner.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ package ai.rapids.cudf; +import ai.rapids.cudf.ast.CompiledExpression; import ai.rapids.cudf.nvcomp.BatchedLZ4Decompressor; import ai.rapids.cudf.nvcomp.Decompressor; import org.slf4j.Logger; @@ -272,6 +273,10 @@ static void register(CuFileHandle handle, Cleaner cleaner) { all.add(new CleanerWeakReference(handle, cleaner, collected, false)); } + public static void register(CompiledExpression expr, Cleaner cleaner) { + all.add(new CleanerWeakReference(expr, cleaner, collected, false)); + } + /** * This is not 100% perfect and we can still run into situations where RMM buffers were not * collected and this returns false because of thread race conditions. This is just a best effort. diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index ea261410585..627a2a36e9e 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -112,7 +112,7 @@ ColumnVector[] getColumns() { } /** Return the native table view handle for this table */ - long getNativeView() { + public long getNativeView() { return nativeHandle; } diff --git a/java/src/main/java/ai/rapids/cudf/ast/AstNode.java b/java/src/main/java/ai/rapids/cudf/ast/AstNode.java new file mode 100644 index 00000000000..78cf39b05d2 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/ast/AstNode.java @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.rapids.cudf.ast; + +import java.nio.ByteBuffer; + +/** Base class of every node in an AST */ +abstract class AstNode { + /** + * Enumeration for the types of AST nodes that can appear in a serialized AST. + * NOTE: This must be kept in sync with the `jni_serialized_node_type` in CompiledExpression.cpp! + */ + protected enum NodeType { + VALID_LITERAL(0), + NULL_LITERAL(1), + COLUMN_REFERENCE(2), + UNARY_EXPRESSION(3), + BINARY_EXPRESSION(4); + + private final byte nativeId; + + NodeType(int nativeId) { + this.nativeId = (byte) nativeId; + assert this.nativeId == nativeId; + } + + /** Get the size in bytes to serialize this node type */ + int getSerializedSize() { + return Byte.BYTES; + } + + /** Serialize this node type to the specified buffer */ + void serialize(ByteBuffer bb) { + bb.put(nativeId); + } + } + + /** Get the size in bytes of the serialized form of this node and all child nodes */ + abstract int getSerializedSize(); + + /** + * Serialize this node and all child nodes. + * @param bb buffer to receive the serialized data + */ + abstract void serialize(ByteBuffer bb); +} diff --git a/java/src/main/java/ai/rapids/cudf/ast/BinaryExpression.java b/java/src/main/java/ai/rapids/cudf/ast/BinaryExpression.java new file mode 100644 index 00000000000..ed4f95b01e1 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/ast/BinaryExpression.java @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.rapids.cudf.ast; + +import java.nio.ByteBuffer; + +/** A binary expression consisting of an operator and two operands. */ +public class BinaryExpression extends Expression { + private final BinaryOperator op; + private final AstNode leftInput; + private final AstNode rightInput; + + public BinaryExpression(BinaryOperator op, AstNode leftInput, AstNode rightInput) { + this.op = op; + this.leftInput = leftInput; + this.rightInput = rightInput; + } + + @Override + int getSerializedSize() { + return NodeType.BINARY_EXPRESSION.getSerializedSize() + + op.getSerializedSize() + + leftInput.getSerializedSize() + + rightInput.getSerializedSize(); + } + + @Override + void serialize(ByteBuffer bb) { + NodeType.BINARY_EXPRESSION.serialize(bb); + op.serialize(bb); + leftInput.serialize(bb); + rightInput.serialize(bb); + } +} diff --git a/java/src/main/java/ai/rapids/cudf/ast/BinaryOperator.java b/java/src/main/java/ai/rapids/cudf/ast/BinaryOperator.java new file mode 100644 index 00000000000..12e4d985658 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/ast/BinaryOperator.java @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.rapids.cudf.ast; + +import java.nio.ByteBuffer; + +/** + * Enumeration of AST operations that can appear in a binary expression. + * NOTE: This must be kept in sync with `jni_to_binary_operator` in CompiledExpression.cpp! + */ +public enum BinaryOperator { + ADD(0), // operator + + SUB(1), // operator - + MUL(2), // operator * + DIV(3), // operator / using common type of lhs and rhs + TRUE_DIV(4), // operator / after promoting type to floating point + FLOOR_DIV(5), // operator / after promoting to 64 bit floating point and then flooring the result + MOD(6), // operator % + PYMOD(7), // operator % but following python's sign rules for negatives + POW(8), // lhs ^ rhs + EQUAL(9), // operator == + NOT_EQUAL(10), // operator != + LESS(11), // operator < + GREATER(12), // operator > + LESS_EQUAL(13), // operator <= + GREATER_EQUAL(14), // operator >= + BITWISE_AND(15), // operator & + BITWISE_OR(16), // operator | + BITWISE_XOR(17), // operator ^ + LOGICAL_AND(18), // operator && + LOGICAL_OR(19); // operator || + + private final byte nativeId; + + BinaryOperator(int nativeId) { + this.nativeId = (byte) nativeId; + assert this.nativeId == nativeId; + } + + /** Get the size in bytes to serialize this operator */ + int getSerializedSize() { + return Byte.BYTES; + } + + /** Serialize this operator to the specified buffer */ + void serialize(ByteBuffer bb) { + bb.put(nativeId); + } +} diff --git a/java/src/main/java/ai/rapids/cudf/ast/ColumnReference.java b/java/src/main/java/ai/rapids/cudf/ast/ColumnReference.java new file mode 100644 index 00000000000..34e4064e23b --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/ast/ColumnReference.java @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.rapids.cudf.ast; + +import java.nio.ByteBuffer; + +/** A reference to a column in an input table. */ +public final class ColumnReference extends AstNode { + private final int columnIndex; + private final TableReference tableSource; + + /** Construct a column reference to either the only or leftmost input table */ + public ColumnReference(int columnIndex) { + this(columnIndex, TableReference.LEFT); + } + + /** Construct a column reference to the specified column index in the specified table */ + public ColumnReference(int columnIndex, TableReference tableSource) { + this.columnIndex = columnIndex; + this.tableSource = tableSource; + } + + @Override + int getSerializedSize() { + // node type + table ref + column index + return NodeType.COLUMN_REFERENCE.getSerializedSize() + + tableSource.getSerializedSize() + + Integer.BYTES; + } + + @Override + void serialize(ByteBuffer bb) { + NodeType.COLUMN_REFERENCE.serialize(bb); + tableSource.serialize(bb); + bb.putInt(columnIndex); + } +} diff --git a/java/src/main/java/ai/rapids/cudf/ast/CompiledExpression.java b/java/src/main/java/ai/rapids/cudf/ast/CompiledExpression.java new file mode 100644 index 00000000000..0d2a0052e29 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/ast/CompiledExpression.java @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.rapids.cudf.ast; + +import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.MemoryCleaner; +import ai.rapids.cudf.Table; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** This class wraps a native compiled AST and must be closed to avoid native memory leaks. */ +public class CompiledExpression implements AutoCloseable { + private static final Logger log = LoggerFactory.getLogger(CompiledExpression.class); + + private static class CompiledExpressionCleaner extends MemoryCleaner.Cleaner { + private long nativeHandle; + + CompiledExpressionCleaner(long nativeHandle) { + this.nativeHandle = nativeHandle; + } + + @Override + protected synchronized boolean cleanImpl(boolean logErrorIfNotClean) { + long origAddress = nativeHandle; + boolean neededCleanup = nativeHandle != 0; + if (neededCleanup) { + try { + destroy(nativeHandle); + } finally { + nativeHandle = 0; + } + if (logErrorIfNotClean) { + log.error("AN AST COMPILED EXPRESSION WAS LEAKED (ID: " + + id + " " + Long.toHexString(origAddress)); + } + } + return neededCleanup; + } + + @Override + public boolean isClean() { + return nativeHandle == 0; + } + } + + private final CompiledExpressionCleaner cleaner; + private boolean isClosed = false; + + /** Construct a compiled expression from a serialized AST */ + CompiledExpression(byte[] serializedExpression) { + this(compile(serializedExpression)); + } + + /** Construct a compiled expression from a native compiled AST pointer */ + CompiledExpression(long nativeHandle) { + this.cleaner = new CompiledExpressionCleaner(nativeHandle); + MemoryCleaner.register(this, cleaner); + cleaner.addRef(); + } + + /** + * Compute a new column by applying this AST expression to the specified table. All + * {@link ColumnReference} instances within the expression will use the sole input table, + * even if they try to specify a non-existent table, e.g.: {@link TableReference#RIGHT}. + * @param table input table for this expression + * @return new column computed from this expression applied to the input table + */ + public ColumnVector computeColumn(Table table) { + return new ColumnVector(computeColumn(cleaner.nativeHandle, table.getNativeView())); + } + + @Override + public synchronized void close() { + cleaner.delRef(); + if (isClosed) { + cleaner.logRefCountDebug("double free " + this); + throw new IllegalStateException("Close called too many times " + this); + } + cleaner.clean(false); + isClosed = true; + } + + private static native long compile(byte[] serializedExpression); + private static native long computeColumn(long astHandle, long tableHandle); + private static native void destroy(long handle); +} diff --git a/java/src/main/java/ai/rapids/cudf/ast/Expression.java b/java/src/main/java/ai/rapids/cudf/ast/Expression.java new file mode 100644 index 00000000000..8d391298cef --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/ast/Expression.java @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.rapids.cudf.ast; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** Base class of every AST expression. */ +public abstract class Expression extends AstNode { + public CompiledExpression compile() { + int size = getSerializedSize(); + ByteBuffer bb = ByteBuffer.allocate(size); + bb.order(ByteOrder.nativeOrder()); + serialize(bb); + return new CompiledExpression(bb.array()); + } +} diff --git a/java/src/main/java/ai/rapids/cudf/ast/Literal.java b/java/src/main/java/ai/rapids/cudf/ast/Literal.java new file mode 100644 index 00000000000..be306cd99c4 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/ast/Literal.java @@ -0,0 +1,264 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.rapids.cudf.ast; + +import ai.rapids.cudf.DType; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** A literal value in an AST expression. */ +public final class Literal extends AstNode { + private final DType type; + private final byte[] serializedValue; + + /** Construct a null literal of the specified type. */ + public static Literal ofNull(DType type) { + return new Literal(type, null); + } + + /** Construct a boolean literal with the specified value. */ + public static Literal ofBoolean(boolean value) { + return new Literal(DType.BOOL8, new byte[] { value ? (byte) 1 : (byte) 0 }); + } + + /** Construct a boolean literal with the specified value or null. */ + public static Literal ofBoolean(Boolean value) { + if (value == null) { + return ofNull(DType.BOOL8); + } + return ofBoolean(value.booleanValue()); + } + + /** Construct a byte literal with the specified value. */ + public static Literal ofByte(byte value) { + return new Literal(DType.INT8, new byte[] { value }); + } + + /** Construct a byte literal with the specified value or null. */ + public static Literal ofByte(Byte value) { + if (value == null) { + return ofNull(DType.INT8); + } + return ofByte(value.byteValue()); + } + + /** Construct a short literal with the specified value. */ + public static Literal ofShort(short value) { + byte[] serializedValue = new byte[Short.BYTES]; + ByteBuffer.wrap(serializedValue).order(ByteOrder.nativeOrder()).putShort(value); + return new Literal(DType.INT16, serializedValue); + } + + /** Construct a short literal with the specified value or null. */ + public static Literal ofShort(Short value) { + if (value == null) { + return ofNull(DType.INT16); + } + return ofShort(value.shortValue()); + } + + /** Construct an integer literal with the specified value. */ + public static Literal ofInt(int value) { + return ofIntBasedType(DType.INT32, value); + } + + /** Construct an integer literal with the specified value or null. */ + public static Literal ofInt(Integer value) { + if (value == null) { + return ofNull(DType.INT32); + } + return ofInt(value.intValue()); + } + + /** Construct a long literal with the specified value. */ + public static Literal ofLong(long value) { + return ofLongBasedType(DType.INT64, value); + } + + /** Construct a long literal with the specified value or null. */ + public static Literal ofLong(Long value) { + if (value == null) { + return ofNull(DType.INT64); + } + return ofLong(value.longValue()); + } + + /** Construct a float literal with the specified value. */ + public static Literal ofFloat(float value) { + byte[] serializedValue = new byte[Float.BYTES]; + ByteBuffer.wrap(serializedValue).order(ByteOrder.nativeOrder()).putFloat(value); + return new Literal(DType.FLOAT32, serializedValue); + } + + /** Construct a float literal with the specified value or null. */ + public static Literal ofFloat(Float value) { + if (value == null) { + return ofNull(DType.FLOAT32); + } + return ofFloat(value.floatValue()); + } + + /** Construct a double literal with the specified value. */ + public static Literal ofDouble(double value) { + byte[] serializedValue = new byte[Double.BYTES]; + ByteBuffer.wrap(serializedValue).order(ByteOrder.nativeOrder()).putDouble(value); + return new Literal(DType.FLOAT64, serializedValue); + } + + /** Construct a double literal with the specified value or null. */ + public static Literal ofDouble(Double value) { + if (value == null) { + return ofNull(DType.FLOAT64); + } + return ofDouble(value.doubleValue()); + } + + /** Construct a timestamp days literal with the specified value. */ + public static Literal ofTimestampDaysFromInt(int value) { + return ofIntBasedType(DType.TIMESTAMP_DAYS, value); + } + + /** Construct a timestamp days literal with the specified value or null. */ + public static Literal ofTimestampDaysFromInt(Integer value) { + if (value == null) { + return ofNull(DType.TIMESTAMP_DAYS); + } + return ofTimestampDaysFromInt(value.intValue()); + } + + /** Construct a long-based timestamp literal with the specified value. */ + public static Literal ofTimestampFromLong(DType type, long value) { + if (!type.isTimestampType()) { + throw new IllegalArgumentException("type is not a timestamp: " + type); + } + if (type.equals(DType.TIMESTAMP_DAYS)) { + int intValue = (int)value; + if (value != intValue) { + throw new IllegalArgumentException("value too large for type " + type + ": " + value); + } + return ofTimestampDaysFromInt(intValue); + } + return ofLongBasedType(type, value); + } + + /** Construct a long-based timestamp literal with the specified value or null. */ + public static Literal ofTimestampFromLong(DType type, Long value) { + if (value == null) { + return ofNull(type); + } + return ofTimestampFromLong(type, value.longValue()); + } + + /** Construct a duration days literal with the specified value. */ + public static Literal ofDurationDaysFromInt(int value) { + return ofIntBasedType(DType.DURATION_DAYS, value); + } + + /** Construct a duration days literal with the specified value or null. */ + public static Literal ofDurationDaysFromInt(Integer value) { + if (value == null) { + return ofNull(DType.DURATION_DAYS); + } + return ofDurationDaysFromInt(value.intValue()); + } + + /** Construct a long-based duration literal with the specified value. */ + public static Literal ofDurationFromLong(DType type, long value) { + if (!type.isDurationType()) { + throw new IllegalArgumentException("type is not a timestamp: " + type); + } + if (type.equals(DType.DURATION_DAYS)) { + int intValue = (int)value; + if (value != intValue) { + throw new IllegalArgumentException("value too large for type " + type + ": " + value); + } + return ofDurationDaysFromInt(intValue); + } + return ofLongBasedType(type, value); + } + + /** Construct a long-based duration literal with the specified value or null. */ + public static Literal ofDurationFromLong(DType type, Long value) { + if (value == null) { + return ofNull(type); + } + return ofDurationFromLong(type, value.longValue()); + } + + Literal(DType type, byte[] serializedValue) { + this.type = type; + this.serializedValue = serializedValue; + } + + @Override + int getSerializedSize() { + NodeType nodeType = serializedValue != null + ? NodeType.VALID_LITERAL : NodeType.NULL_LITERAL; + int size = nodeType.getSerializedSize() + getDataTypeSerializedSize(); + if (serializedValue != null) { + size += serializedValue.length; + } + return size; + } + + @Override + void serialize(ByteBuffer bb) { + NodeType nodeType = serializedValue != null + ? NodeType.VALID_LITERAL : NodeType.NULL_LITERAL; + nodeType.serialize(bb); + serializeDataType(bb); + if (serializedValue != null) { + bb.put(serializedValue); + } + } + + private int getDataTypeSerializedSize() { + int nativeTypeId = type.getTypeId().getNativeId(); + assert nativeTypeId == (byte) nativeTypeId : "Type ID does not fit in a byte"; + if (type.isDecimalType()) { + assert type.getScale() == (byte) type.getScale() : "Decimal scale does not fit in a byte"; + return 2; + } + return 1; + } + + private void serializeDataType(ByteBuffer bb) { + byte nativeTypeId = (byte) type.getTypeId().getNativeId(); + assert nativeTypeId == type.getTypeId().getNativeId() : "DType ID does not fit in a byte"; + bb.put(nativeTypeId); + if (type.isDecimalType()) { + byte scale = (byte) type.getScale(); + assert scale == (byte) type.getScale() : "Decimal scale does not fit in a byte"; + bb.put(scale); + } + } + + private static Literal ofIntBasedType(DType type, int value) { + assert type.getSizeInBytes() == Integer.BYTES; + byte[] serializedValue = new byte[Integer.BYTES]; + ByteBuffer.wrap(serializedValue).order(ByteOrder.nativeOrder()).putInt(value); + return new Literal(type, serializedValue); + } + + private static Literal ofLongBasedType(DType type, long value) { + assert type.getSizeInBytes() == Long.BYTES; + byte[] serializedValue = new byte[Long.BYTES]; + ByteBuffer.wrap(serializedValue).order(ByteOrder.nativeOrder()).putLong(value); + return new Literal(type, serializedValue); + } +} diff --git a/java/src/main/java/ai/rapids/cudf/ast/TableReference.java b/java/src/main/java/ai/rapids/cudf/ast/TableReference.java new file mode 100644 index 00000000000..12255779a49 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/ast/TableReference.java @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.rapids.cudf.ast; + + +import java.nio.ByteBuffer; + +/** + * Enumeration of tables that can be referenced in an AST. + * NOTE: This must be kept in sync with `jni_to_table_reference` code in CompiledExpression.cpp! + */ +public enum TableReference { + LEFT(0), + RIGHT(1); + // OUTPUT is an AST implementation detail and should not appear in user-built expressions. + + private final byte nativeId; + + TableReference(int nativeId) { + this.nativeId = (byte) nativeId; + assert this.nativeId == nativeId; + } + + /** Get the size in bytes to serialize this table reference */ + int getSerializedSize() { + return Byte.BYTES; + } + + /** Serialize this table reference to the specified buffer */ + void serialize(ByteBuffer bb) { + bb.put(nativeId); + } +} diff --git a/java/src/main/java/ai/rapids/cudf/ast/UnaryExpression.java b/java/src/main/java/ai/rapids/cudf/ast/UnaryExpression.java new file mode 100644 index 00000000000..fa8e70266ac --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/ast/UnaryExpression.java @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.rapids.cudf.ast; + +import java.nio.ByteBuffer; + +/** A unary expression consisting of an operator and an operand. */ +public final class UnaryExpression extends Expression { + private final UnaryOperator op; + private final AstNode input; + + public UnaryExpression(UnaryOperator op, AstNode input) { + this.op = op; + this.input = input; + } + + @Override + int getSerializedSize() { + return NodeType.UNARY_EXPRESSION.getSerializedSize() + + op.getSerializedSize() + + input.getSerializedSize(); + } + + @Override + void serialize(ByteBuffer bb) { + NodeType.UNARY_EXPRESSION.serialize(bb); + op.serialize(bb); + input.serialize(bb); + } +} diff --git a/java/src/main/java/ai/rapids/cudf/ast/UnaryOperator.java b/java/src/main/java/ai/rapids/cudf/ast/UnaryOperator.java new file mode 100644 index 00000000000..c3f193d06b4 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/ast/UnaryOperator.java @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.rapids.cudf.ast; + +import java.nio.ByteBuffer; + +/** + * Enumeration of AST operations that can appear in a unary expression. + * NOTE: This must be kept in sync with `jni_to_unary_operator` in CompiledExpression.cpp! + */ +public enum UnaryOperator { + IDENTITY(0), // Identity function + SIN(1), // Trigonometric sine + COS(2), // Trigonometric cosine + TAN(3), // Trigonometric tangent + ARCSIN(4), // Trigonometric sine inverse + ARCCOS(5), // Trigonometric cosine inverse + ARCTAN(6), // Trigonometric tangent inverse + SINH(7), // Hyperbolic sine + COSH(8), // Hyperbolic cosine + TANH(9), // Hyperbolic tangent + ARCSINH(10), // Hyperbolic sine inverse + ARCCOSH(11), // Hyperbolic cosine inverse + ARCTANH(12), // Hyperbolic tangent inverse + EXP(13), // Exponential (base e, Euler number) + LOG(14), // Natural Logarithm (base e) + SQRT(15), // Square-root (x^0.5) + CBRT(16), // Cube-root (x^(1.0/3)) + CEIL(17), // Smallest integer value not less than arg + FLOOR(18), // largest integer value not greater than arg + ABS(19), // Absolute value + RINT(20), // Rounds the floating-point argument arg to an integer value + BIT_INVERT(21), // Bitwise Not (~) + NOT(22); // Logical Not (!) + + private final byte nativeId; + + UnaryOperator(int nativeId) { + this.nativeId = (byte) nativeId; + assert this.nativeId == nativeId; + } + /** Get the size in bytes to serialize this operator */ + int getSerializedSize() { + return Byte.BYTES; + } + + /** Serialize this operator to the specified buffer */ + void serialize(ByteBuffer bb) { + bb.put(nativeId); + } +} diff --git a/java/src/main/native/CMakeLists.txt b/java/src/main/native/CMakeLists.txt index c018c0aa742..a938a2af456 100755 --- a/java/src/main/native/CMakeLists.txt +++ b/java/src/main/native/CMakeLists.txt @@ -261,6 +261,7 @@ set(SOURCE_FILES "src/CudaJni.cpp" "src/ColumnVectorJni.cpp" "src/ColumnViewJni.cpp" + "src/CompiledExpression.cpp" "src/ContiguousTableJni.cpp" "src/HostMemoryBufferNativeUtilsJni.cpp" "src/NvcompJni.cpp" diff --git a/java/src/main/native/src/CompiledExpression.cpp b/java/src/main/native/src/CompiledExpression.cpp new file mode 100644 index 00000000000..a28160b32a3 --- /dev/null +++ b/java/src/main/native/src/CompiledExpression.cpp @@ -0,0 +1,437 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "cudf_jni_apis.hpp" + +namespace cudf { +namespace jni { +namespace ast { + +/** + * A class to capture all of the resources associated with a compiled AST expression. + * AST nodes do not own their child nodes, so every node in the expression tree + * must be explicitly tracked in order to free the underlying resources for each node. + * + * This should be cleaned up a bit after the libcudf AST refactoring in + * https://github.com/rapidsai/cudf/pull/8815 when a virtual destructor is added to the + * base AST node type. Then we do not have to track every AST node type separately. + */ +class compiled_expr { + /** All literal nodes within the expression tree */ + std::vector> literals; + + /** All column reference nodes within the expression tree */ + std::vector> column_refs; + + /** All expression nodes within the expression tree */ + std::vector> expressions; + + /** GPU scalar instances that correspond to literal nodes */ + std::vector> scalars; + +public: + cudf::ast::literal &add_literal(std::unique_ptr literal_ptr, + std::unique_ptr scalar_ptr) { + literals.push_back(std::move(literal_ptr)); + scalars.push_back(std::move(scalar_ptr)); + return *literals.back(); + } + + cudf::ast::column_reference & + add_column_ref(std::unique_ptr ref_ptr) { + column_refs.push_back(std::move(ref_ptr)); + return *column_refs.back(); + } + + cudf::ast::expression &add_expression(std::unique_ptr expr_ptr) { + expressions.push_back(std::move(expr_ptr)); + return *expressions.back(); + } + + /** Return the expression node at the top of the tree */ + cudf::ast::expression &get_top_expression() const { return *expressions.back(); } +}; + +} // namespace ast +} // namespace jni +} // namespace cudf + +namespace { + +/** Utility class to read data from the serialized AST buffer generated from Java */ +class jni_serialized_ast { + jbyte const *data_ptr; // pointer to the current entity to deserialize + jbyte const *const end_ptr; // pointer to the byte immediately after the AST serialized data + + /** Throws an error if there is insufficient space left to read the specified number of bytes */ + void check_for_eof(std::size_t num_bytes_to_read) { + if (data_ptr + num_bytes_to_read > end_ptr) { + throw std::runtime_error("Unexpected end of serialized data"); + } + } + +public: + jni_serialized_ast(cudf::jni::native_jbyteArray &jni_data) + : data_ptr(jni_data.begin()), end_ptr(jni_data.end()) {} + + /** Returns true if there is no data remaining to be read */ + bool at_eof() { return data_ptr == end_ptr; } + + /** Read a byte from the serialized AST data buffer */ + jbyte read_byte() { + check_for_eof(sizeof(jbyte)); + return *data_ptr++; + } + + /** Read a multi-byte value from the serialized AST data buffer */ + template T read() { + check_for_eof(sizeof(T)); + // use memcpy since data may be misaligned + T result; + memcpy(&result, data_ptr, sizeof(T)); + data_ptr += sizeof(T); + return result; + } + + /** Decode a libcudf data type from the serialized AST data buffer */ + cudf::data_type read_cudf_type() { + auto const dtype_id = static_cast(read_byte()); + switch (dtype_id) { + case cudf::type_id::INT8: + case cudf::type_id::INT16: + case cudf::type_id::INT32: + case cudf::type_id::INT64: + case cudf::type_id::UINT8: + case cudf::type_id::UINT16: + case cudf::type_id::UINT32: + case cudf::type_id::UINT64: + case cudf::type_id::FLOAT32: + case cudf::type_id::FLOAT64: + case cudf::type_id::BOOL8: + case cudf::type_id::TIMESTAMP_DAYS: + case cudf::type_id::TIMESTAMP_SECONDS: + case cudf::type_id::TIMESTAMP_MILLISECONDS: + case cudf::type_id::TIMESTAMP_MICROSECONDS: + case cudf::type_id::TIMESTAMP_NANOSECONDS: + case cudf::type_id::DURATION_DAYS: + case cudf::type_id::DURATION_SECONDS: + case cudf::type_id::DURATION_MILLISECONDS: + case cudf::type_id::DURATION_MICROSECONDS: + case cudf::type_id::DURATION_NANOSECONDS: + case cudf::type_id::STRING: { + return cudf::data_type(dtype_id); + } + case cudf::type_id::DECIMAL32: + case cudf::type_id::DECIMAL64: { + int32_t const scale = read_byte(); + return cudf::data_type(dtype_id, scale); + } + default: throw new std::invalid_argument("unrecognized cudf data type"); + } + } +}; + +/** + * Enumeration of the AST node types that can appear in the serialized data. + * NOTE: This must be kept in sync with the NodeType enumeration in AstNode.java! + */ +enum class jni_serialized_node_type : int8_t { + VALID_LITERAL = 0, + NULL_LITERAL = 1, + COLUMN_REFERENCE = 2, + UNARY_EXPRESSION = 3, + BINARY_EXPRESSION = 4 +}; + +/** + * Convert a Java AST serialized byte representing an AST unary operator into the + * corresponding libcudf AST operator. + * NOTE: This must be kept in sync with the enumeration in UnaryOperator.java! + */ +cudf::ast::ast_operator jni_to_unary_operator(jbyte jni_op_value) { + switch (jni_op_value) { + case 0: return cudf::ast::ast_operator::IDENTITY; + case 1: return cudf::ast::ast_operator::SIN; + case 2: return cudf::ast::ast_operator::COS; + case 3: return cudf::ast::ast_operator::TAN; + case 4: return cudf::ast::ast_operator::ARCSIN; + case 5: return cudf::ast::ast_operator::ARCCOS; + case 6: return cudf::ast::ast_operator::ARCTAN; + case 7: return cudf::ast::ast_operator::SINH; + case 8: return cudf::ast::ast_operator::COSH; + case 9: return cudf::ast::ast_operator::TANH; + case 10: return cudf::ast::ast_operator::ARCSINH; + case 11: return cudf::ast::ast_operator::ARCCOSH; + case 12: return cudf::ast::ast_operator::ARCTANH; + case 13: return cudf::ast::ast_operator::EXP; + case 14: return cudf::ast::ast_operator::LOG; + case 15: return cudf::ast::ast_operator::SQRT; + case 16: return cudf::ast::ast_operator::CBRT; + case 17: return cudf::ast::ast_operator::CEIL; + case 18: return cudf::ast::ast_operator::FLOOR; + case 19: return cudf::ast::ast_operator::ABS; + case 20: return cudf::ast::ast_operator::RINT; + case 21: return cudf::ast::ast_operator::BIT_INVERT; + case 22: return cudf::ast::ast_operator::NOT; + default: throw std::invalid_argument("unexpected JNI AST unary operator value"); + } +} + +/** + * Convert a Java AST serialized byte representing an AST binary operator into the + * corresponding libcudf AST operator. + * NOTE: This must be kept in sync with the enumeration in BinaryOperator.java! + */ +cudf::ast::ast_operator jni_to_binary_operator(jbyte jni_op_value) { + switch (jni_op_value) { + case 0: return cudf::ast::ast_operator::ADD; + case 1: return cudf::ast::ast_operator::SUB; + case 2: return cudf::ast::ast_operator::MUL; + case 3: return cudf::ast::ast_operator::DIV; + case 4: return cudf::ast::ast_operator::TRUE_DIV; + case 5: return cudf::ast::ast_operator::FLOOR_DIV; + case 6: return cudf::ast::ast_operator::MOD; + case 7: return cudf::ast::ast_operator::PYMOD; + case 8: return cudf::ast::ast_operator::POW; + case 9: return cudf::ast::ast_operator::EQUAL; + case 10: return cudf::ast::ast_operator::NOT_EQUAL; + case 11: return cudf::ast::ast_operator::LESS; + case 12: return cudf::ast::ast_operator::GREATER; + case 13: return cudf::ast::ast_operator::LESS_EQUAL; + case 14: return cudf::ast::ast_operator::GREATER_EQUAL; + case 15: return cudf::ast::ast_operator::BITWISE_AND; + case 16: return cudf::ast::ast_operator::BITWISE_OR; + case 17: return cudf::ast::ast_operator::BITWISE_XOR; + case 18: return cudf::ast::ast_operator::LOGICAL_AND; + case 19: return cudf::ast::ast_operator::LOGICAL_OR; + default: throw std::invalid_argument("unexpected JNI AST binary operator value"); + } +} + +/** + * Convert a Java AST serialized byte representing an AST table reference into the + * corresponding libcudf AST table reference. + * NOTE: This must be kept in sync with the enumeration in TableReference.java! + */ +cudf::ast::table_reference jni_to_table_reference(jbyte jni_value) { + switch (jni_value) { + case 0: return cudf::ast::table_reference::LEFT; + case 1: return cudf::ast::table_reference::RIGHT; + default: throw std::invalid_argument("unexpected JNI table reference value"); + } +} + +/** Functor for type-dispatching the creation of an AST literal */ +struct make_literal { + /** Construct an AST literal from a numeric value */ + template ()> * = nullptr> + cudf::ast::literal &operator()(cudf::data_type dtype, bool is_valid, + cudf::jni::ast::compiled_expr &compiled_expr, + jni_serialized_ast &jni_ast) { + std::unique_ptr scalar_ptr = cudf::make_numeric_scalar(dtype); + scalar_ptr->set_valid_async(is_valid); + if (is_valid) { + T val = jni_ast.read(); + using ScalarType = cudf::scalar_type_t; + static_cast(scalar_ptr.get())->set_value(val); + } + + auto &numeric_scalar = static_cast &>(*scalar_ptr); + return compiled_expr.add_literal(std::make_unique(numeric_scalar), + std::move(scalar_ptr)); + } + + /** Construct an AST literal from a timestamp value */ + template ()> * = nullptr> + cudf::ast::literal &operator()(cudf::data_type dtype, bool is_valid, + cudf::jni::ast::compiled_expr &compiled_expr, + jni_serialized_ast &jni_ast) { + std::unique_ptr scalar_ptr = cudf::make_timestamp_scalar(dtype); + scalar_ptr->set_valid_async(is_valid); + if (is_valid) { + T val = jni_ast.read(); + using ScalarType = cudf::scalar_type_t; + static_cast(scalar_ptr.get())->set_value(val); + } + + auto ×tamp_scalar = static_cast &>(*scalar_ptr); + return compiled_expr.add_literal(std::make_unique(timestamp_scalar), + std::move(scalar_ptr)); + } + + /** Construct an AST literal from a duration value */ + template ()> * = nullptr> + cudf::ast::literal &operator()(cudf::data_type dtype, bool is_valid, + cudf::jni::ast::compiled_expr &compiled_expr, + jni_serialized_ast &jni_ast) { + std::unique_ptr scalar_ptr = cudf::make_duration_scalar(dtype); + scalar_ptr->set_valid_async(is_valid); + if (is_valid) { + T val = jni_ast.read(); + using ScalarType = cudf::scalar_type_t; + static_cast(scalar_ptr.get())->set_value(val); + } + + auto &duration_scalar = static_cast &>(*scalar_ptr); + return compiled_expr.add_literal(std::make_unique(duration_scalar), + std::move(scalar_ptr)); + } + + /** Default functor implementation to catch type dispatch errors */ + template () && !cudf::is_timestamp() && + !cudf::is_duration()> * = nullptr> + cudf::ast::literal &operator()(cudf::data_type dtype, bool is_valid, + cudf::jni::ast::compiled_expr &compiled_expr, + jni_serialized_ast &jni_ast) { + throw std::logic_error("Unsupported AST literal type"); + } +}; + +/** Decode a serialized AST literal */ +cudf::ast::literal &compile_literal(bool is_valid, cudf::jni::ast::compiled_expr &compiled_expr, + jni_serialized_ast &jni_ast) { + auto const dtype = jni_ast.read_cudf_type(); + return cudf::type_dispatcher(dtype, make_literal{}, dtype, is_valid, compiled_expr, jni_ast); +} + +/** Decode a serialized AST column reference */ +cudf::ast::column_reference &compile_column_reference(cudf::jni::ast::compiled_expr &compiled_expr, + jni_serialized_ast &jni_ast) { + auto const table_ref = jni_to_table_reference(jni_ast.read_byte()); + cudf::size_type const column_index = jni_ast.read(); + return compiled_expr.add_column_ref( + std::make_unique(column_index, table_ref)); +} + +// forward declaration +cudf::ast::detail::node &compile_node(cudf::jni::ast::compiled_expr &compiled_expr, + jni_serialized_ast &jni_ast); + +/** Decode a serialized AST unary expression */ +cudf::ast::expression &compile_unary_expression(cudf::jni::ast::compiled_expr &compiled_expr, + jni_serialized_ast &jni_ast) { + auto const ast_op = jni_to_unary_operator(jni_ast.read_byte()); + cudf::ast::detail::node &child_node = compile_node(compiled_expr, jni_ast); + return compiled_expr.add_expression(std::make_unique(ast_op, child_node)); +} + +/** Decode a serialized AST binary expression */ +cudf::ast::expression &compile_binary_expression(cudf::jni::ast::compiled_expr &compiled_expr, + jni_serialized_ast &jni_ast) { + auto const ast_op = jni_to_binary_operator(jni_ast.read_byte()); + cudf::ast::detail::node &left_child = compile_node(compiled_expr, jni_ast); + cudf::ast::detail::node &right_child = compile_node(compiled_expr, jni_ast); + return compiled_expr.add_expression( + std::make_unique(ast_op, left_child, right_child)); +} + +/** Decode a serialized AST node by reading the node type and dispatching */ +cudf::ast::detail::node &compile_node(cudf::jni::ast::compiled_expr &compiled_expr, + jni_serialized_ast &jni_ast) { + auto const node_type = static_cast(jni_ast.read_byte()); + switch (node_type) { + case jni_serialized_node_type::VALID_LITERAL: + return compile_literal(true, compiled_expr, jni_ast); + case jni_serialized_node_type::NULL_LITERAL: + return compile_literal(false, compiled_expr, jni_ast); + case jni_serialized_node_type::COLUMN_REFERENCE: + return compile_column_reference(compiled_expr, jni_ast); + case jni_serialized_node_type::UNARY_EXPRESSION: + return compile_unary_expression(compiled_expr, jni_ast); + case jni_serialized_node_type::BINARY_EXPRESSION: + return compile_binary_expression(compiled_expr, jni_ast); + default: throw std::invalid_argument("data is not a serialized AST expression"); + } +} + +/** Decode a serialized AST into a native libcudf AST and associated resources */ +std::unique_ptr compile_serialized_ast(jni_serialized_ast &jni_ast) { + auto jni_expr_ptr = std::make_unique(); + auto const node_type = static_cast(jni_ast.read_byte()); + switch (node_type) { + case jni_serialized_node_type::UNARY_EXPRESSION: + (void)compile_unary_expression(*jni_expr_ptr, jni_ast); + break; + case jni_serialized_node_type::BINARY_EXPRESSION: + (void)compile_binary_expression(*jni_expr_ptr, jni_ast); + break; + default: throw std::invalid_argument("data is not a serialized AST expression"); + } + + if (!jni_ast.at_eof()) { + throw std::invalid_argument("Extra bytes at end of serialized AST"); + } + + return jni_expr_ptr; +} + +} // anonymous namespace + +extern "C" { + +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ast_CompiledExpression_compile(JNIEnv *env, jclass, + jbyteArray jni_data) { + JNI_NULL_CHECK(env, jni_data, "Serialized AST data is null", 0); + try { + cudf::jni::auto_set_device(env); + cudf::jni::native_jbyteArray jbytes(env, jni_data); + jni_serialized_ast jni_ast(jbytes); + auto compiled_expr_ptr = compile_serialized_ast(jni_ast); + jbytes.cancel(); + return reinterpret_cast(compiled_expr_ptr.release()); + } + CATCH_STD(env, 0); +} + +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ast_CompiledExpression_computeColumn(JNIEnv *env, + jclass, + jlong j_ast, + jlong j_table) { + JNI_NULL_CHECK(env, j_ast, "Compiled AST pointer is null", 0); + JNI_NULL_CHECK(env, j_table, "Table view pointer is null", 0); + try { + cudf::jni::auto_set_device(env); + auto compiled_expr_ptr = reinterpret_cast(j_ast); + auto tview_ptr = reinterpret_cast(j_table); + std::unique_ptr result = + cudf::ast::compute_column(*tview_ptr, compiled_expr_ptr->get_top_expression()); + return reinterpret_cast(result.release()); + } + CATCH_STD(env, 0); +} + +JNIEXPORT void JNICALL Java_ai_rapids_cudf_ast_CompiledExpression_destroy(JNIEnv *env, jclass, + jlong jni_handle) { + try { + cudf::jni::auto_set_device(env); + auto ptr = reinterpret_cast(jni_handle); + delete ptr; + } + CATCH_STD(env, ); +} + +} // extern "C" diff --git a/java/src/test/java/ai/rapids/cudf/ast/CompiledExpressionTest.java b/java/src/test/java/ai/rapids/cudf/ast/CompiledExpressionTest.java new file mode 100644 index 00000000000..5a64fd6ab09 --- /dev/null +++ b/java/src/test/java/ai/rapids/cudf/ast/CompiledExpressionTest.java @@ -0,0 +1,576 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.rapids.cudf.ast; + +import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.CudfException; +import ai.rapids.cudf.CudfTestBase; +import ai.rapids.cudf.DType; +import ai.rapids.cudf.Table; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Stream; + +import static ai.rapids.cudf.TableTest.assertColumnsAreEqual; + +public class CompiledExpressionTest extends CudfTestBase { + @Test + public void testColumnReferenceTransform() { + try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build()) { + // use an implicit table reference + UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, + new ColumnReference(1)); + try (CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t)) { + assertColumnsAreEqual(t.getColumn(1), actual); + } + + // use an explicit table reference + expr = new UnaryExpression(UnaryOperator.IDENTITY, + new ColumnReference(1, TableReference.LEFT)); + try (CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t)) { + assertColumnsAreEqual(t.getColumn(1), actual); + } + } + } + + @Test + public void testInvalidColumnReferenceTransform() { + // verify attempting to reference an invalid table remaps to the only valid table + UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, + new ColumnReference(1, TableReference.RIGHT)); + try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t)) { + assertColumnsAreEqual(t.getColumn(1), actual); + } + } + + @Test + public void testBooleanLiteralTransform() { + try (Table t = new Table.TestBuilder().column(true, false, null).build()) { + Literal trueLiteral = Literal.ofBoolean(true); + UnaryExpression trueExpr = new UnaryExpression(UnaryOperator.IDENTITY, trueLiteral); + try (CompiledExpression trueCompiledExpr = trueExpr.compile(); + ColumnVector trueExprActual = trueCompiledExpr.computeColumn(t); + ColumnVector trueExprExpected = ColumnVector.fromBoxedBooleans(true, true, true)) { + assertColumnsAreEqual(trueExprExpected, trueExprActual); + } + + // Uncomment the following after https://github.com/rapidsai/cudf/issues/8831 is fixed + // Literal nullLiteral = Literal.ofBoolean(null); + // UnaryExpression nullExpr = new UnaryExpression(AstOperator.IDENTITY, nullLiteral); + // try (CompiledExpression nullCompiledExpr = nullExpr.compile(); + // ColumnVector nullExprActual = nullCompiledExpr.computeColumn(t); + // ColumnVector nullExprExpected = ColumnVector.fromBoxedBooleans(null, null, null)) { + // assertColumnsAreEqual(nullExprExpected, nullExprActual); + // } + } + } + + @ParameterizedTest + // Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed + // @NullSource + @ValueSource(bytes = 0x12) + public void testByteLiteralTransform(Byte value) { + Literal literal = Literal.ofByte(value); + UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = ColumnVector.fromBoxedBytes(value, value, value, value, value)) { + assertColumnsAreEqual(expected, actual); + } + } + + @ParameterizedTest + // Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed + // @NullSource + @ValueSource(shorts = 0x1234) + public void testShortLiteralTransform(Short value) { + Literal literal = Literal.ofShort(value); + UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = ColumnVector.fromBoxedShorts(value, value, value, value, value)) { + assertColumnsAreEqual(expected, actual); + } + } + + @ParameterizedTest + // Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed + // @NullSource + @ValueSource(ints = 0x12345678) + public void testIntLiteralTransform(Integer value) { + Literal literal = Literal.ofInt(value); + UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = ColumnVector.fromBoxedInts(value, value, value, value, value)) { + assertColumnsAreEqual(expected, actual); + } + } + + @ParameterizedTest + // Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed + // @NullSource + @ValueSource(longs = 0x1234567890abcdefL) + public void testLongLiteralTransform(Long value) { + Literal literal = Literal.ofLong(value); + UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = ColumnVector.fromBoxedLongs(value, value, value, value, value)) { + assertColumnsAreEqual(expected, actual); + } + } + + @ParameterizedTest + // Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed + // @NullSource + @ValueSource(floats = { 123456.789f, Float.NaN, Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY} ) + public void testFloatLiteralTransform(Float value) { + Literal literal = Literal.ofFloat(value); + UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = ColumnVector.fromBoxedFloats(value, value, value, value, value)) { + assertColumnsAreEqual(expected, actual); + } + } + + @ParameterizedTest + // Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed + // @NullSource + @ValueSource(doubles = { 123456.789f, Double.NaN, Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY} ) + public void testDoubleLiteralTransform(Double value) { + Literal literal = Literal.ofDouble(value); + UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = ColumnVector.fromBoxedDoubles(value, value, value, value, value)) { + assertColumnsAreEqual(expected, actual); + } + } + + @ParameterizedTest + // Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed + // @NullSource + @ValueSource(ints = 0x12345678) + public void testTimestampDaysLiteralTransform(Integer value) { + Literal literal = Literal.ofTimestampDaysFromInt(value); + UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = + ColumnVector.timestampDaysFromBoxedInts(value, value, value, value, value)) { + assertColumnsAreEqual(expected, actual); + } + } + + @ParameterizedTest + // Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed + // @NullSource + @ValueSource(longs = 0x1234567890abcdefL) + public void testTimestampSecondsLiteralTransform(Long value) { + Literal literal = Literal.ofTimestampFromLong(DType.TIMESTAMP_SECONDS, value); + UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = + ColumnVector.timestampSecondsFromBoxedLongs(value, value, value, value, value)) { + assertColumnsAreEqual(expected, actual); + } + } + + @ParameterizedTest + // Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed + // @NullSource + @ValueSource(longs = 0x1234567890abcdefL) + public void testTimestampMilliSecondsLiteralTransform(Long value) { + Literal literal = Literal.ofTimestampFromLong(DType.TIMESTAMP_MILLISECONDS, value); + UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = + ColumnVector.timestampMilliSecondsFromBoxedLongs(value, value, value, value, value)) { + assertColumnsAreEqual(expected, actual); + } + } + + @ParameterizedTest + // Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed + // @NullSource + @ValueSource(longs = 0x1234567890abcdefL) + public void testTimestampMicroSecondsLiteralTransform(Long value) { + Literal literal = Literal.ofTimestampFromLong(DType.TIMESTAMP_MICROSECONDS, value); + UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = + ColumnVector.timestampMicroSecondsFromBoxedLongs(value, value, value, value, value)) { + assertColumnsAreEqual(expected, actual); + } + } + + @ParameterizedTest + // Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed + // @NullSource + @ValueSource(longs = 0x1234567890abcdefL) + public void testTimestampNanoSecondsLiteralTransform(Long value) { + Literal literal = Literal.ofTimestampFromLong(DType.TIMESTAMP_NANOSECONDS, value); + UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = + ColumnVector.timestampNanoSecondsFromBoxedLongs(value, value, value, value, value)) { + assertColumnsAreEqual(expected, actual); + } + } + + @ParameterizedTest + // Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed + // @NullSource + @ValueSource(ints = 0x12345678) + public void testDurationDaysLiteralTransform(Integer value) { + Literal literal = Literal.ofDurationDaysFromInt(value); + UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = + ColumnVector.durationDaysFromBoxedInts(value, value, value, value, value)) { + assertColumnsAreEqual(expected, actual); + } + } + + @ParameterizedTest + // Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed + // @NullSource + @ValueSource(longs = 0x1234567890abcdefL) + public void testDurationSecondsLiteralTransform(Long value) { + Literal literal = Literal.ofDurationFromLong(DType.DURATION_SECONDS, value); + UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = + ColumnVector.durationSecondsFromBoxedLongs(value, value, value, value, value)) { + assertColumnsAreEqual(expected, actual); + } + } + + @ParameterizedTest + // Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed + // @NullSource + @ValueSource(longs = 0x1234567890abcdefL) + public void testDurationMilliSecondsLiteralTransform(Long value) { + Literal literal = Literal.ofDurationFromLong(DType.DURATION_MILLISECONDS, value); + UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = + ColumnVector.durationMilliSecondsFromBoxedLongs(value, value, value, value, value)) { + assertColumnsAreEqual(expected, actual); + } + } + + @ParameterizedTest + // Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed + // @NullSource + @ValueSource(longs = 0x1234567890abcdefL) + public void testDurationMicroSecondsLiteralTransform(Long value) { + Literal literal = Literal.ofDurationFromLong(DType.DURATION_MICROSECONDS, value); + UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = + ColumnVector.durationMicroSecondsFromBoxedLongs(value, value, value, value, value)) { + assertColumnsAreEqual(expected, actual); + } + } + + @ParameterizedTest + // Uncomment the following line after https://github.com/rapidsai/cudf/issues/8831 is fixed + // @NullSource + @ValueSource(longs = 0x1234567890abcdefL) + public void testDurationNanoSecondsLiteralTransform(Long value) { + Literal literal = Literal.ofDurationFromLong(DType.DURATION_NANOSECONDS, value); + UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = + ColumnVector.durationNanoSecondsFromBoxedLongs(value, value, value, value, value)) { + assertColumnsAreEqual(expected, actual); + } + } + + private static ArrayList mapArray(T[] input, Function func) { + ArrayList result = new ArrayList<>(input.length); + for (T t : input) { + result.add(t == null ? null : func.apply(t)); + } + return result; + } + + private static ArrayList mapArray(T[] in1, U[] in2, BiFunction func) { + assert in1.length == in2.length; + ArrayList result = new ArrayList<>(in1.length); + for (int i = 0; i < in1.length; i++) { + result.add(in1[i] == null || in2[i] == null ? null : func.apply(in1[i], in2[i])); + } + return result; + } + + private static Stream createUnaryDoubleExpressionParams() { + Double[] input = new Double[] { -5., 4.5, null, 2.7, 1.5 }; + return Stream.of( + Arguments.of(UnaryOperator.IDENTITY, input, Arrays.asList(input)), + Arguments.of(UnaryOperator.SIN, input, mapArray(input, Math::sin)), + Arguments.of(UnaryOperator.COS, input, mapArray(input, Math::cos)), + Arguments.of(UnaryOperator.TAN, input, mapArray(input, Math::tan)), + Arguments.of(UnaryOperator.ARCSIN, input, mapArray(input, Math::asin)), + Arguments.of(UnaryOperator.ARCCOS, input, mapArray(input, Math::acos)), + Arguments.of(UnaryOperator.ARCTAN, input, mapArray(input, Math::atan)), + Arguments.of(UnaryOperator.SINH, input, mapArray(input, Math::sinh)), + Arguments.of(UnaryOperator.COSH, input, mapArray(input, Math::cosh)), + Arguments.of(UnaryOperator.TANH, input, mapArray(input, Math::tanh)), + Arguments.of(UnaryOperator.EXP, input, mapArray(input, Math::exp)), + Arguments.of(UnaryOperator.LOG, input, mapArray(input, Math::log)), + Arguments.of(UnaryOperator.SQRT, input, mapArray(input, Math::sqrt)), + Arguments.of(UnaryOperator.CBRT, input, mapArray(input, Math::cbrt)), + Arguments.of(UnaryOperator.CEIL, input, mapArray(input, Math::ceil)), + Arguments.of(UnaryOperator.FLOOR, input, mapArray(input, Math::floor)), + Arguments.of(UnaryOperator.ABS, input, mapArray(input, Math::abs)), + Arguments.of(UnaryOperator.RINT, input, mapArray(input, Math::rint))); + } + + @ParameterizedTest + @MethodSource("createUnaryDoubleExpressionParams") + void testUnaryDoubleExpressionTransform(UnaryOperator op, Double[] input, + List expectedValues) { + UnaryExpression expr = new UnaryExpression(op, new ColumnReference(0)); + try (Table t = new Table.TestBuilder().column(input).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = ColumnVector.fromBoxedDoubles( + expectedValues.toArray(new Double[0]))) { + assertColumnsAreEqual(expected, actual); + } + } + + @Test + void testUnaryShortExpressionTransform() { + Short[] input = new Short[] { -5, 4, null, 2, 1 }; + try (Table t = new Table.TestBuilder().column(input).build()) { + UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, new ColumnReference(0)); + try (CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t)) { + assertColumnsAreEqual(t.getColumn(0), actual); + } + + expr = new UnaryExpression(UnaryOperator.BIT_INVERT, new ColumnReference(0)); + try (CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = ColumnVector.fromBoxedInts(4, -5, null, -3, -2)) { + assertColumnsAreEqual(expected, actual); + } + } + } + + @Test + void testUnaryLogicalExpressionTransform() { + UnaryExpression expr = new UnaryExpression(UnaryOperator.NOT, new ColumnReference(0)); + try (Table t = new Table.TestBuilder().column(-5L, 0L, null, 2L, 1L).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = ColumnVector.fromBoxedBooleans(false, true, null, false, false)) { + assertColumnsAreEqual(expected, actual); + } + } + + private static Stream createBinaryFloatExpressionParams() { + Float[] in1 = new Float[] { -5f, 4.5f, null, 2.7f }; + Float[] in2 = new Float[] { 123f, -456f, null, 0f }; + return Stream.of( + Arguments.of(BinaryOperator.ADD, in1, in2, mapArray(in1, in2, Float::sum)), + Arguments.of(BinaryOperator.SUB, in1, in2, mapArray(in1, in2, (a, b) -> a - b)), + Arguments.of(BinaryOperator.MUL, in1, in2, mapArray(in1, in2, (a, b) -> a * b)), + Arguments.of(BinaryOperator.DIV, in1, in2, mapArray(in1, in2, (a, b) -> a / b)), + Arguments.of(BinaryOperator.MOD, in1, in2, mapArray(in1, in2, (a, b) -> a % b)), + Arguments.of(BinaryOperator.PYMOD, in1, in2, mapArray(in1, in2, + (a, b) -> ((a % b) + b) % b)), + Arguments.of(BinaryOperator.POW, in1, in2, mapArray(in1, in2, + (a, b) -> (float) Math.pow(a, b)))); + } + + @ParameterizedTest + @MethodSource("createBinaryFloatExpressionParams") + void testBinaryFloatExpressionTransform(BinaryOperator op, Float[] in1, Float[] in2, + List expectedValues) { + BinaryExpression expr = new BinaryExpression(op, + new ColumnReference(0), + new ColumnReference(1)); + try (Table t = new Table.TestBuilder().column(in1).column(in2).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = ColumnVector.fromBoxedFloats( + expectedValues.toArray(new Float[0]))) { + assertColumnsAreEqual(expected, actual); + } + } + + private static Stream createBinaryDoublePromotedExpressionParams() { + Float[] in1 = new Float[] { -5f, 4.5f, null, 2.7f }; + Float[] in2 = new Float[] { 123f, -456f, null, 0f }; + return Stream.of( + Arguments.of(BinaryOperator.TRUE_DIV, in1, in2, mapArray(in1, in2, + (a, b) -> (double) a / b)), + Arguments.of(BinaryOperator.FLOOR_DIV, in1, in2, mapArray(in1, in2, + (a, b) -> Math.floor(a / b)))); + } + + @ParameterizedTest + @MethodSource("createBinaryDoublePromotedExpressionParams") + void testBinaryDoublePromotedExpressionTransform(BinaryOperator op, Float[] in1, Float[] in2, + List expectedValues) { + BinaryExpression expr = new BinaryExpression(op, + new ColumnReference(0), + new ColumnReference(1)); + try (Table t = new Table.TestBuilder().column(in1).column(in2).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = ColumnVector.fromBoxedDoubles( + expectedValues.toArray(new Double[0]))) { + assertColumnsAreEqual(expected, actual); + } + } + + private static Stream createBinaryComparisonExpressionParams() { + Integer[] in1 = new Integer[] { -5, 4, null, 2, -3 }; + Integer[] in2 = new Integer[] { 123, -456, null, 0, -3 }; + return Stream.of( + // nulls compare as equal by default + Arguments.of(BinaryOperator.EQUAL, in1, in2, Arrays.asList(false, false, true, false, true)), + Arguments.of(BinaryOperator.NOT_EQUAL, in1, in2, mapArray(in1, in2, (a, b) -> !a.equals(b))), + Arguments.of(BinaryOperator.LESS, in1, in2, mapArray(in1, in2, (a, b) -> a < b)), + Arguments.of(BinaryOperator.GREATER, in1, in2, mapArray(in1, in2, (a, b) -> a > b)), + Arguments.of(BinaryOperator.LESS_EQUAL, in1, in2, mapArray(in1, in2, (a, b) -> a <= b)), + Arguments.of(BinaryOperator.GREATER_EQUAL, in1, in2, mapArray(in1, in2, (a, b) -> a >= b))); + } + + @ParameterizedTest + @MethodSource("createBinaryComparisonExpressionParams") + void testBinaryComparisonExpressionTransform(BinaryOperator op, Integer[] in1, Integer[] in2, + List expectedValues) { + BinaryExpression expr = new BinaryExpression(op, + new ColumnReference(0), + new ColumnReference(1)); + try (Table t = new Table.TestBuilder().column(in1).column(in2).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = ColumnVector.fromBoxedBooleans( + expectedValues.toArray(new Boolean[0]))) { + assertColumnsAreEqual(expected, actual); + } + } + + private static Stream createBinaryBitwiseExpressionParams() { + Integer[] in1 = new Integer[] { -5, 4, null, 2, -3 }; + Integer[] in2 = new Integer[] { 123, -456, null, 0, -3 }; + return Stream.of( + Arguments.of(BinaryOperator.BITWISE_AND, in1, in2, mapArray(in1, in2, (a, b) -> a & b)), + Arguments.of(BinaryOperator.BITWISE_OR, in1, in2, mapArray(in1, in2, (a, b) -> a | b)), + Arguments.of(BinaryOperator.BITWISE_XOR, in1, in2, mapArray(in1, in2, (a, b) -> a ^ b))); + } + + @ParameterizedTest + @MethodSource("createBinaryBitwiseExpressionParams") + void testBinaryBitwiseExpressionTransform(BinaryOperator op, Integer[] in1, Integer[] in2, + List expectedValues) { + BinaryExpression expr = new BinaryExpression(op, + new ColumnReference(0), + new ColumnReference(1)); + try (Table t = new Table.TestBuilder().column(in1).column(in2).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = ColumnVector.fromBoxedInts( + expectedValues.toArray(new Integer[0]))) { + assertColumnsAreEqual(expected, actual); + } + } + + private static Stream createBinaryBooleanExpressionParams() { + Boolean[] in1 = new Boolean[] { false, true, null, true, false }; + Boolean[] in2 = new Boolean[] { true, null, null, true, false }; + return Stream.of( + Arguments.of(BinaryOperator.LOGICAL_AND, in1, in2, mapArray(in1, in2, (a, b) -> a && b)), + Arguments.of(BinaryOperator.LOGICAL_OR, in1, in2, mapArray(in1, in2, (a, b) -> a || b))); + } + + @ParameterizedTest + @MethodSource("createBinaryBooleanExpressionParams") + void testBinaryBooleanExpressionTransform(BinaryOperator op, Boolean[] in1, Boolean[] in2, + List expectedValues) { + BinaryExpression expr = new BinaryExpression(op, + new ColumnReference(0), + new ColumnReference(1)); + try (Table t = new Table.TestBuilder().column(in1).column(in2).build(); + CompiledExpression compiledExpr = expr.compile(); + ColumnVector actual = compiledExpr.computeColumn(t); + ColumnVector expected = ColumnVector.fromBoxedBooleans( + expectedValues.toArray(new Boolean[0]))) { + assertColumnsAreEqual(expected, actual); + } + } + + @Test + void testMismatchedBinaryExpressionTypes() { + // verify expression fails to transform if operands are not the same type + BinaryExpression expr = new BinaryExpression(BinaryOperator.ADD, + new ColumnReference(0), + new ColumnReference(1)); + try (Table t = new Table.TestBuilder().column(1, 2, 3).column(1L, 2L, 3L).build(); + CompiledExpression compiledExpr = expr.compile()) { + Assertions.assertThrows(CudfException.class, () -> compiledExpr.computeColumn(t).close()); + } + } +}