diff --git a/conda/recipes/libcudf/meta.yaml b/conda/recipes/libcudf/meta.yaml index c1ba2b495eb..208c21c2dc0 100644 --- a/conda/recipes/libcudf/meta.yaml +++ b/conda/recipes/libcudf/meta.yaml @@ -53,8 +53,7 @@ test: - test -f $PREFIX/include/cudf/aggregation.hpp - test -f $PREFIX/include/cudf/ast/detail/expression_parser.hpp - test -f $PREFIX/include/cudf/ast/detail/operators.hpp - - test -f $PREFIX/include/cudf/ast/nodes.hpp - - test -f $PREFIX/include/cudf/ast/operators.hpp + - test -f $PREFIX/include/cudf/ast/expressions.hpp - test -f $PREFIX/include/cudf/binaryop.hpp - test -f $PREFIX/include/cudf/labeling/label_bins.hpp - test -f $PREFIX/include/cudf/column/column_factories.hpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 54a4c4ea023..0e98258d7c2 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -153,6 +153,7 @@ add_library(cudf src/aggregation/aggregation.cu src/aggregation/result_cache.cpp src/ast/expression_parser.cpp + src/ast/expressions.cpp src/binaryop/binaryop.cpp src/binaryop/compiled/binary_ops.cu src/binaryop/compiled/Add.cu diff --git a/cpp/benchmarks/ast/transform_benchmark.cpp b/cpp/benchmarks/ast/transform_benchmark.cpp index 75b502bf7bf..fd0a0f7d2c8 100644 --- a/cpp/benchmarks/ast/transform_benchmark.cpp +++ b/cpp/benchmarks/ast/transform_benchmark.cpp @@ -95,22 +95,22 @@ static void BM_ast_transform(benchmark::State& state) // Note that a std::list is required here because of its guarantees against reference invalidation // when items are added or removed. References to items in a std::vector are not safe if the // vector must re-allocate. 
- auto expressions = std::list(); + auto expressions = std::list(); // Construct tree that chains additions like (((a + b) + c) + d) auto const op = cudf::ast::ast_operator::ADD; if (reuse_columns) { - expressions.push_back(cudf::ast::expression(op, column_refs.at(0), column_refs.at(0))); + expressions.push_back(cudf::ast::operation(op, column_refs.at(0), column_refs.at(0))); for (cudf::size_type i = 0; i < tree_levels - 1; i++) { - expressions.push_back(cudf::ast::expression(op, expressions.back(), column_refs.at(0))); + expressions.push_back(cudf::ast::operation(op, expressions.back(), column_refs.at(0))); } } else { - expressions.push_back(cudf::ast::expression(op, column_refs.at(0), column_refs.at(1))); + expressions.push_back(cudf::ast::operation(op, column_refs.at(0), column_refs.at(1))); std::transform(std::next(column_refs.cbegin(), 2), column_refs.cend(), std::back_inserter(expressions), [&](auto const& column_ref) { - return cudf::ast::expression(op, expressions.back(), column_ref); + return cudf::ast::operation(op, expressions.back(), column_ref); }); } diff --git a/cpp/benchmarks/join/conditional_join_benchmark.cu b/cpp/benchmarks/join/conditional_join_benchmark.cu index f778f6ac010..71b90685fb9 100644 --- a/cpp/benchmarks/join/conditional_join_benchmark.cu +++ b/cpp/benchmarks/join/conditional_join_benchmark.cu @@ -26,7 +26,7 @@ class ConditionalJoin : public cudf::benchmark { { \ auto join = [](cudf::table_view const& left, \ cudf::table_view const& right, \ - cudf::ast::expression binary_pred, \ + cudf::ast::operation binary_pred, \ cudf::null_equality compare_nulls) { \ return cudf::conditional_inner_join(left, right, binary_pred, compare_nulls); \ }; \ @@ -45,7 +45,7 @@ CONDITIONAL_INNER_JOIN_BENCHMARK_DEFINE(conditional_inner_join_64bit_nulls, int6 { \ auto join = [](cudf::table_view const& left, \ cudf::table_view const& right, \ - cudf::ast::expression binary_pred, \ + cudf::ast::operation binary_pred, \ cudf::null_equality compare_nulls) { \ 
return cudf::conditional_left_join(left, right, binary_pred, compare_nulls); \ }; \ @@ -64,7 +64,7 @@ CONDITIONAL_LEFT_JOIN_BENCHMARK_DEFINE(conditional_left_join_64bit_nulls, int64_ { \ auto join = [](cudf::table_view const& left, \ cudf::table_view const& right, \ - cudf::ast::expression binary_pred, \ + cudf::ast::operation binary_pred, \ cudf::null_equality compare_nulls) { \ return cudf::conditional_inner_join(left, right, binary_pred, compare_nulls); \ }; \ @@ -83,7 +83,7 @@ CONDITIONAL_FULL_JOIN_BENCHMARK_DEFINE(conditional_full_join_64bit_nulls, int64_ { \ auto join = [](cudf::table_view const& left, \ cudf::table_view const& right, \ - cudf::ast::expression binary_pred, \ + cudf::ast::operation binary_pred, \ cudf::null_equality compare_nulls) { \ return cudf::conditional_left_anti_join(left, right, binary_pred, compare_nulls); \ }; \ @@ -114,7 +114,7 @@ CONDITIONAL_LEFT_ANTI_JOIN_BENCHMARK_DEFINE(conditional_left_anti_join_64bit_nul { \ auto join = [](cudf::table_view const& left, \ cudf::table_view const& right, \ - cudf::ast::expression binary_pred, \ + cudf::ast::operation binary_pred, \ cudf::null_equality compare_nulls) { \ return cudf::conditional_left_semi_join(left, right, binary_pred, compare_nulls); \ }; \ @@ -145,11 +145,6 @@ BENCHMARK_REGISTER_F(ConditionalJoin, conditional_inner_join_32bit) ->Args({100'000, 100'000}) ->Args({100'000, 400'000}) ->Args({100'000, 1'000'000}) - // TODO: The below benchmark is slow, but can be useful to validate that the - // code works for large data sets. This benchmark was used to compare to the - // otherwise equivalent nullable benchmark below, which has memory errors for - // sufficiently large data sets. 
- //->Args({1'000'000, 1'000'000}) ->UseManualTime(); BENCHMARK_REGISTER_F(ConditionalJoin, conditional_inner_join_64bit) diff --git a/cpp/benchmarks/join/join_benchmark_common.hpp b/cpp/benchmarks/join/join_benchmark_common.hpp index e6fed454707..add87bf7dfb 100644 --- a/cpp/benchmarks/join/join_benchmark_common.hpp +++ b/cpp/benchmarks/join/join_benchmark_common.hpp @@ -21,6 +21,7 @@ #include +#include #include #include #include @@ -139,7 +140,7 @@ static void BM_join(state_type& state, Join JoinFunc) const auto col_ref_left_0 = cudf::ast::column_reference(0); const auto col_ref_right_0 = cudf::ast::column_reference(0, cudf::ast::table_reference::RIGHT); auto left_zero_eq_right_zero = - cudf::ast::expression(cudf::ast::ast_operator::EQUAL, col_ref_left_0, col_ref_right_0); + cudf::ast::operation(cudf::ast::ast_operator::EQUAL, col_ref_left_0, col_ref_right_0); for (auto _ : state) { cuda_event_timer raii(state, true, rmm::cuda_stream_default); diff --git a/cpp/include/cudf/ast/detail/expression_evaluator.cuh b/cpp/include/cudf/ast/detail/expression_evaluator.cuh index ca2cab96123..fb198761115 100644 --- a/cpp/include/cudf/ast/detail/expression_evaluator.cuh +++ b/cpp/include/cudf/ast/detail/expression_evaluator.cuh @@ -17,8 +17,7 @@ #include #include -#include -#include +#include #include #include #include diff --git a/cpp/include/cudf/ast/detail/expression_parser.hpp b/cpp/include/cudf/ast/detail/expression_parser.hpp index 9eca250b898..1f35b54ea61 100644 --- a/cpp/include/cudf/ast/detail/expression_parser.hpp +++ b/cpp/include/cudf/ast/detail/expression_parser.hpp @@ -15,8 +15,7 @@ */ #pragma once -#include -#include +#include #include #include #include @@ -44,7 +43,7 @@ enum class device_data_reference_type { }; /** - * @brief A device data reference describes a source of data used by a node. + * @brief A device data reference describes a source of data used by an expression. 
* * This is a POD class used to create references describing data type and locations for consumption * by the `row_evaluator`. @@ -115,11 +114,11 @@ struct expression_device_view { * @brief The expression_parser traverses an expression and converts it into a form suitable for * execution on the device. * - * This class is part of a "visitor" pattern with the `node` class. + * This class is part of a "visitor" pattern with the `expression` class. * * This class does pre-processing work on the host, validating operators and operand data types. It - * traverses downward from a root node in a depth-first fashion, capturing information about - * the nodes and constructing vectors of information that are later used by the device for + * traverses downward from a root expression in a depth-first fashion, capturing information about + * the expressions and constructing vectors of information that are later used by the device for * evaluating the abstract syntax tree as a "linear" list of operators whose input dependencies are * resolved into intermediate data storage in shared memory. */ @@ -132,13 +131,17 @@ class expression_parser { * @param left The left table used for evaluating the abstract syntax tree. * @param right The right table used for evaluating the abstract syntax tree. */ - expression_parser(node const& expr, + expression_parser(expression const& expr, cudf::table_view const& left, std::optional> right, bool has_nulls, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) - : _left{left}, _right{right}, _node_count{0}, _intermediate_counter{}, _has_nulls(has_nulls) + : _left{left}, + _right{right}, + _expression_count{0}, + _intermediate_counter{}, + _has_nulls(has_nulls) { expr.accept(*this); move_to_device(stream, mr); @@ -150,7 +153,7 @@ class expression_parser { * @param expr The expression to create an evaluable expression_parser for. * @param table The table used for evaluating the abstract syntax tree. 
*/ - expression_parser(node const& expr, + expression_parser(expression const& expr, cudf::table_view const& table, bool has_nulls, rmm::cuda_stream_view stream, @@ -167,33 +170,33 @@ class expression_parser { cudf::data_type output_type() const; /** - * @brief Visit a literal node. + * @brief Visit a literal expression. * - * @param expr Literal node. - * @return cudf::size_type Index of device data reference for the node. + * @param expr Literal expression. + * @return cudf::size_type Index of device data reference for the expression. */ cudf::size_type visit(literal const& expr); /** - * @brief Visit a column reference node. + * @brief Visit a column reference expression. * - * @param expr Column reference node. - * @return cudf::size_type Index of device data reference for the node. + * @param expr Column reference expression. + * @return cudf::size_type Index of device data reference for the expression. */ cudf::size_type visit(column_reference const& expr); /** - * @brief Visit an expression node. + * @brief Visit an operation expression. * - * @param expr Expression node. - * @return cudf::size_type Index of device data reference for the node. + * @param expr Operation expression. + * @return cudf::size_type Index of device data reference for the expression. */ - cudf::size_type visit(expression const& expr); + cudf::size_type visit(operation const& expr); /** * @brief Internal class used to track the utilization of intermediate storage locations. * - * As nodes are being evaluated, they may generate "intermediate" data that is immediately + * As expressions are being evaluated, they may generate "intermediate" data that is immediately * consumed. Rather than manifesting this data in global memory, we can store intermediates of any * fixed width type (up to 8 bytes) by placing them in shared memory. This class helps to track * the number and indices of intermediate data in shared memory using a give-take model. 
Locations @@ -308,7 +311,7 @@ class expression_parser { * @return The indices of the operands stored in the data references. */ std::vector visit_operands( - std::vector> operands); + std::vector> operands); /** * @brief Add a data reference to the internal list. @@ -325,7 +328,7 @@ class expression_parser { cudf::table_view const& _left; std::optional> _right; - cudf::size_type _node_count; + cudf::size_type _expression_count; intermediate_counter _intermediate_counter; bool _has_nulls; std::vector _data_references; diff --git a/cpp/include/cudf/ast/detail/operators.hpp b/cpp/include/cudf/ast/detail/operators.hpp index fd3a0775401..00723004a9f 100644 --- a/cpp/include/cudf/ast/detail/operators.hpp +++ b/cpp/include/cudf/ast/detail/operators.hpp @@ -15,7 +15,7 @@ */ #pragma once -#include +#include #include #include #include diff --git a/cpp/include/cudf/ast/nodes.hpp b/cpp/include/cudf/ast/expressions.hpp similarity index 59% rename from cpp/include/cudf/ast/nodes.hpp rename to cpp/include/cudf/ast/expressions.hpp index f36d7bcd3c7..d9ba197f8fe 100644 --- a/cpp/include/cudf/ast/nodes.hpp +++ b/cpp/include/cudf/ast/expressions.hpp @@ -15,8 +15,6 @@ */ #pragma once -#include -#include #include #include #include @@ -25,21 +23,75 @@ namespace cudf { namespace ast { -namespace detail { -// Forward declaration +// Forward declaration. +namespace detail { class expression_parser; +} + /** - * @brief A generic node that can be evaluated to return a value. + * @brief A generic expression that can be evaluated to return a value. * * This class is a part of a "visitor" pattern with the `linearizer` class. * Nodes inheriting from this class can accept visitors. */ -struct node { - virtual cudf::size_type accept(expression_parser& visitor) const = 0; +struct expression { + virtual cudf::size_type accept(detail::expression_parser& visitor) const = 0; + + virtual ~expression() {} }; -} // namespace detail +/** + * @brief Enum of supported operators. 
+ */ +enum class ast_operator { + // Binary operators + ADD, ///< operator + + SUB, ///< operator - + MUL, ///< operator * + DIV, ///< operator / using common type of lhs and rhs + TRUE_DIV, ///< operator / after promoting type to floating point + FLOOR_DIV, ///< operator / after promoting to 64 bit floating point and then + ///< flooring the result + MOD, ///< operator % + PYMOD, ///< operator % but following python's sign rules for negatives + POW, ///< lhs ^ rhs + EQUAL, ///< operator == + NOT_EQUAL, ///< operator != + LESS, ///< operator < + GREATER, ///< operator > + LESS_EQUAL, ///< operator <= + GREATER_EQUAL, ///< operator >= + BITWISE_AND, ///< operator & + BITWISE_OR, ///< operator | + BITWISE_XOR, ///< operator ^ + LOGICAL_AND, ///< operator && + LOGICAL_OR, ///< operator || + // Unary operators + IDENTITY, ///< Identity function + SIN, ///< Trigonometric sine + COS, ///< Trigonometric cosine + TAN, ///< Trigonometric tangent + ARCSIN, ///< Trigonometric sine inverse + ARCCOS, ///< Trigonometric cosine inverse + ARCTAN, ///< Trigonometric tangent inverse + SINH, ///< Hyperbolic sine + COSH, ///< Hyperbolic cosine + TANH, ///< Hyperbolic tangent + ARCSINH, ///< Hyperbolic sine inverse + ARCCOSH, ///< Hyperbolic cosine inverse + ARCTANH, ///< Hyperbolic tangent inverse + EXP, ///< Exponential (base e, Euler number) + LOG, ///< Natural Logarithm (base e) + SQRT, ///< Square-root (x^0.5) + CBRT, ///< Cube-root (x^(1.0/3)) + CEIL, ///< Smallest integer value not less than arg + FLOOR, ///< largest integer value not greater than arg + ABS, ///< Absolute value + RINT, ///< Rounds the floating-point argument arg to an integer value + BIT_INVERT, ///< Bitwise Not (~) + NOT ///< Logical Not (!) +}; /** * @brief Enum of table references. @@ -55,7 +107,7 @@ enum class table_reference { /** * @brief A literal value used in an abstract syntax tree. 
*/ -class literal : public detail::node { +class literal : public expression { public: /** * @brief Construct a new literal object. @@ -117,14 +169,14 @@ class literal : public detail::node { }; /** - * @brief A node referring to data from a column in a table. + * @brief An expression referring to data from a column in a table. */ -class column_reference : public detail::node { +class column_reference : public expression { public: /** * @brief Construct a new column reference object * - * @param column_index Index of this column in the table (provided when the node is + * @param column_index Index of this column in the table (provided when the expression is * evaluated). * @param table_source Which table to use in cases with two tables (e.g. joins). */ @@ -194,43 +246,33 @@ class column_reference : public detail::node { }; /** - * @brief An expression node holds an operator and zero or more operands. + * @brief An operation expression holds an operator and zero or more operands. */ -class expression : public detail::node { +class operation : public expression { public: /** - * @brief Construct a new unary expression object. + * @brief Construct a new unary operation object. * * @param op Operator - * @param input Input node (operand) + * @param input Input expression (operand) */ - expression(ast_operator op, node const& input) : op(op), operands({input}) - { - if (cudf::ast::detail::ast_operator_arity(op) != 1) { - CUDF_FAIL("The provided operator is not a unary operator."); - } - } + operation(ast_operator op, expression const& input); /** - * @brief Construct a new binary expression object. + * @brief Construct a new binary operation object. 
* * @param op Operator - * @param left Left input node (left operand) - * @param right Right input node (right operand) + * @param left Left input expression (left operand) + * @param right Right input expression (right operand) */ - expression(ast_operator op, node const& left, node const& right) : op(op), operands({left, right}) - { - if (cudf::ast::detail::ast_operator_arity(op) != 2) { - CUDF_FAIL("The provided operator is not a binary operator."); - } - } + operation(ast_operator op, expression const& left, expression const& right); - // expression only stores references to nodes, so it does not accept r-value - // references: the calling code must own the nodes. - expression(ast_operator op, node&& input) = delete; - expression(ast_operator op, node&& left, node&& right) = delete; - expression(ast_operator op, node&& left, node const& right) = delete; - expression(ast_operator op, node const& left, node&& right) = delete; + // operation only stores references to expressions, so it does not accept r-value + // references: the calling code must own the expressions. + operation(ast_operator op, expression&& input) = delete; + operation(ast_operator op, expression&& left, expression&& right) = delete; + operation(ast_operator op, expression&& left, expression const& right) = delete; + operation(ast_operator op, expression const& left, expression&& right) = delete; /** * @brief Get the operator. @@ -242,9 +284,9 @@ class expression : public detail::node { /** * @brief Get the operands. * - * @return std::vector> + * @return std::vector> */ - std::vector> get_operands() const { return operands; } + std::vector> get_operands() const { return operands; } /** * @brief Accepts a visitor class. 
@@ -256,7 +298,7 @@ class expression : public detail::node { private: ast_operator const op; - std::vector> const operands; + std::vector> const operands; }; } // namespace ast diff --git a/cpp/include/cudf/ast/operators.hpp b/cpp/include/cudf/ast/operators.hpp deleted file mode 100644 index 78e56340246..00000000000 --- a/cpp/include/cudf/ast/operators.hpp +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright (c) 2020, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -namespace cudf { - -namespace ast { - -/** - * @brief Enum of supported operators. 
- */ -enum class ast_operator { - // Binary operators - ADD, ///< operator + - SUB, ///< operator - - MUL, ///< operator * - DIV, ///< operator / using common type of lhs and rhs - TRUE_DIV, ///< operator / after promoting type to floating point - FLOOR_DIV, ///< operator / after promoting to 64 bit floating point and then - ///< flooring the result - MOD, ///< operator % - PYMOD, ///< operator % but following python's sign rules for negatives - POW, ///< lhs ^ rhs - EQUAL, ///< operator == - NOT_EQUAL, ///< operator != - LESS, ///< operator < - GREATER, ///< operator > - LESS_EQUAL, ///< operator <= - GREATER_EQUAL, ///< operator >= - BITWISE_AND, ///< operator & - BITWISE_OR, ///< operator | - BITWISE_XOR, ///< operator ^ - LOGICAL_AND, ///< operator && - LOGICAL_OR, ///< operator || - // Unary operators - IDENTITY, ///< Identity function - SIN, ///< Trigonometric sine - COS, ///< Trigonometric cosine - TAN, ///< Trigonometric tangent - ARCSIN, ///< Trigonometric sine inverse - ARCCOS, ///< Trigonometric cosine inverse - ARCTAN, ///< Trigonometric tangent inverse - SINH, ///< Hyperbolic sine - COSH, ///< Hyperbolic cosine - TANH, ///< Hyperbolic tangent - ARCSINH, ///< Hyperbolic sine inverse - ARCCOSH, ///< Hyperbolic cosine inverse - ARCTANH, ///< Hyperbolic tangent inverse - EXP, ///< Exponential (base e, Euler number) - LOG, ///< Natural Logarithm (base e) - SQRT, ///< Square-root (x^0.5) - CBRT, ///< Cube-root (x^(1.0/3)) - CEIL, ///< Smallest integer value not less than arg - FLOOR, ///< largest integer value not greater than arg - ABS, ///< Absolute value - RINT, ///< Rounds the floating-point argument arg to an integer value - BIT_INVERT, ///< Bitwise Not (~) - NOT ///< Logical Not (!) 
-}; - -} // namespace ast - -} // namespace cudf diff --git a/cpp/include/cudf/detail/transform.hpp b/cpp/include/cudf/detail/transform.hpp index 96ef27529be..12948498455 100644 --- a/cpp/include/cudf/detail/transform.hpp +++ b/cpp/include/cudf/detail/transform.hpp @@ -16,7 +16,7 @@ #pragma once -#include +#include #include #include @@ -43,7 +43,7 @@ std::unique_ptr transform( */ std::unique_ptr compute_column( table_view const table, - ast::expression const& expr, + ast::operation const& expr, rmm::cuda_stream_view stream = rmm::cuda_stream_default, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); diff --git a/cpp/include/cudf/join.hpp b/cpp/include/cudf/join.hpp index dbafa95ee77..483cd75c739 100644 --- a/cpp/include/cudf/join.hpp +++ b/cpp/include/cudf/join.hpp @@ -16,7 +16,7 @@ #pragma once -#include +#include #include #include @@ -687,9 +687,9 @@ class hash_join { std::pair>, std::unique_ptr>> conditional_inner_join( - table_view left, - table_view right, - ast::expression binary_predicate, + table_view const& left, + table_view const& right, + ast::expression const& binary_predicate, null_equality compare_nulls = null_equality::EQUAL, std::optional output_size = {}, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); @@ -733,9 +733,9 @@ conditional_inner_join( */ std::pair>, std::unique_ptr>> -conditional_left_join(table_view left, - table_view right, - ast::expression binary_predicate, +conditional_left_join(table_view const& left, + table_view const& right, + ast::expression const& binary_predicate, null_equality compare_nulls = null_equality::EQUAL, std::optional output_size = {}, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); @@ -778,9 +778,9 @@ conditional_left_join(table_view left, */ std::pair>, std::unique_ptr>> -conditional_full_join(table_view left, - table_view right, - ast::expression binary_predicate, +conditional_full_join(table_view const& left, + 
table_view const& right, + ast::expression const& binary_predicate, null_equality compare_nulls = null_equality::EQUAL, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); @@ -817,9 +817,9 @@ conditional_full_join(table_view left, * `right` . */ std::unique_ptr> conditional_left_semi_join( - table_view left, - table_view right, - ast::expression binary_predicate, + table_view const& left, + table_view const& right, + ast::expression const& binary_predicate, null_equality compare_nulls = null_equality::EQUAL, std::optional output_size = {}, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); @@ -857,9 +857,9 @@ std::unique_ptr> conditional_left_semi_join( * `right` . */ std::unique_ptr> conditional_left_anti_join( - table_view left, - table_view right, - ast::expression binary_predicate, + table_view const& left, + table_view const& right, + ast::expression const& binary_predicate, null_equality compare_nulls = null_equality::EQUAL, std::optional output_size = {}, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); @@ -883,9 +883,9 @@ std::unique_ptr> conditional_left_anti_join( * @return The size that would result from performing the requested join. */ std::size_t conditional_inner_join_size( - table_view left, - table_view right, - ast::expression binary_predicate, + table_view const& left, + table_view const& right, + ast::expression const& binary_predicate, null_equality compare_nulls = null_equality::EQUAL, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); @@ -908,9 +908,9 @@ std::size_t conditional_inner_join_size( * @return The size that would result from performing the requested join. 
*/ std::size_t conditional_left_join_size( - table_view left, - table_view right, - ast::expression binary_predicate, + table_view const& left, + table_view const& right, + ast::expression const& binary_predicate, null_equality compare_nulls = null_equality::EQUAL, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); @@ -933,9 +933,9 @@ std::size_t conditional_left_join_size( * @return The size that would result from performing the requested join. */ std::size_t conditional_left_semi_join_size( - table_view left, - table_view right, - ast::expression binary_predicate, + table_view const& left, + table_view const& right, + ast::expression const& binary_predicate, null_equality compare_nulls = null_equality::EQUAL, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); @@ -958,9 +958,9 @@ std::size_t conditional_left_semi_join_size( * @return The size that would result from performing the requested join. */ std::size_t conditional_left_anti_join_size( - table_view left, - table_view right, - ast::expression binary_predicate, + table_view const& left, + table_view const& right, + ast::expression const& binary_predicate, null_equality compare_nulls = null_equality::EQUAL, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @} */ // end of group diff --git a/cpp/include/cudf/transform.hpp b/cpp/include/cudf/transform.hpp index cf391b2b23d..af2858d948e 100644 --- a/cpp/include/cudf/transform.hpp +++ b/cpp/include/cudf/transform.hpp @@ -16,7 +16,7 @@ #pragma once -#include +#include #include #include @@ -89,7 +89,7 @@ std::pair, size_type> nans_to_nulls( * @return std::unique_ptr Output column. 
*/ std::unique_ptr compute_column( - table_view const table, + table_view const& table, ast::expression const& expr, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); diff --git a/cpp/src/ast/expression_parser.cpp b/cpp/src/ast/expression_parser.cpp index 760f47a5045..1072bff43dd 100644 --- a/cpp/src/ast/expression_parser.cpp +++ b/cpp/src/ast/expression_parser.cpp @@ -14,8 +14,8 @@ * limitations under the License. */ #include -#include -#include +#include +#include #include #include #include @@ -85,46 +85,57 @@ cudf::size_type expression_parser::intermediate_counter::find_first_missing() co cudf::size_type expression_parser::visit(literal const& expr) { - _node_count++; // Increment the node index - auto const data_type = expr.get_data_type(); // Resolve node type - auto device_view = expr.get_value(); // Construct a scalar device view - auto const literal_index = cudf::size_type(_literals.size()); // Push literal - _literals.push_back(device_view); - auto const source = detail::device_data_reference( - detail::device_data_reference_type::LITERAL, data_type, literal_index); // Push data reference - return add_data_reference(source); + if (_expression_count == 0) { + // Handle the trivial case of a literal as the entire expression. 
+ return visit(operation(ast_operator::IDENTITY, expr)); + } else { + _expression_count++; // Increment the expression index + auto const data_type = expr.get_data_type(); // Resolve expression type + auto device_view = expr.get_value(); // Construct a scalar device view + auto const literal_index = cudf::size_type(_literals.size()); // Push literal + _literals.push_back(device_view); + auto const source = detail::device_data_reference(detail::device_data_reference_type::LITERAL, + data_type, + literal_index); // Push data reference + return add_data_reference(source); + } } cudf::size_type expression_parser::visit(column_reference const& expr) { - // Increment the node index - _node_count++; - // Resolve node type - cudf::data_type data_type; - if (expr.get_table_source() == table_reference::LEFT) { - data_type = expr.get_data_type(_left); + if (_expression_count == 0) { + // Handle the trivial case of a column reference as the entire expression. + return visit(operation(ast_operator::IDENTITY, expr)); } else { - if (_right.has_value()) { - data_type = expr.get_data_type(*_right); + // Increment the expression index + _expression_count++; + // Resolve expression type + cudf::data_type data_type; + if (expr.get_table_source() == table_reference::LEFT) { + data_type = expr.get_data_type(_left); } else { - CUDF_FAIL( - "Your expression contains a reference to the RIGHT table even though it will only be " - "evaluated on a single table (by convention, the LEFT table)."); + if (_right.has_value()) { + data_type = expr.get_data_type(*_right); + } else { + CUDF_FAIL( + "Your expression contains a reference to the RIGHT table even though it will only be " + "evaluated on a single table (by convention, the LEFT table)."); + } } + // Push data reference + auto const source = detail::device_data_reference(detail::device_data_reference_type::COLUMN, + data_type, + expr.get_column_index(), + expr.get_table_source()); + return add_data_reference(source); } - // Push data 
reference - auto const source = detail::device_data_reference(detail::device_data_reference_type::COLUMN, - data_type, - expr.get_column_index(), - expr.get_table_source()); - return add_data_reference(source); } -cudf::size_type expression_parser::visit(expression const& expr) +cudf::size_type expression_parser::visit(operation const& expr) { - // Increment the node index - auto const node_index = _node_count++; - // Visit children (operands) of this node + // Increment the expression index + auto const expression_index = _expression_count++; + // Visit children (operands) of this expression auto const operand_data_ref_indices = visit_operands(expr.get_operands()); // Resolve operand types auto data_ref = [this](auto const& index) { return _data_references[index].data_type; }; @@ -149,18 +160,18 @@ cudf::size_type expression_parser::visit(expression const& expr) _intermediate_counter.give(intermediate_index); } }); - // Resolve node type + // Resolve expression type auto const op = expr.get_operator(); auto const data_type = cudf::ast::detail::ast_operator_return_type(op, operand_types); _operators.push_back(op); // Push data reference auto const output = [&]() { - if (node_index == 0) { - // This node is the root. Output should be directed to the output column. + if (expression_index == 0) { + // This expression is the root. Output should be directed to the output column. return detail::device_data_reference( detail::device_data_reference_type::COLUMN, data_type, 0, table_reference::OUTPUT); } else { - // This node is not the root. Output is an intermediate value. + // This expression is not the root. Output is an intermediate value. // Ensure that the output type is fixed width and fits in the intermediate storage. 
if (!cudf::is_fixed_width(data_type)) { CUDF_FAIL( @@ -189,7 +200,7 @@ cudf::data_type expression_parser::output_type() const } std::vector expression_parser::visit_operands( - std::vector> operands) + std::vector> operands) { auto operand_data_reference_indices = std::vector(); for (auto const& operand : operands) { @@ -214,19 +225,6 @@ cudf::size_type expression_parser::add_data_reference(detail::device_data_refere } // namespace detail -cudf::size_type literal::accept(detail::expression_parser& visitor) const -{ - return visitor.visit(*this); -} -cudf::size_type column_reference::accept(detail::expression_parser& visitor) const -{ - return visitor.visit(*this); -} -cudf::size_type expression::accept(detail::expression_parser& visitor) const -{ - return visitor.visit(*this); -} - } // namespace ast } // namespace cudf diff --git a/cpp/src/ast/expressions.cpp b/cpp/src/ast/expressions.cpp new file mode 100644 index 00000000000..88cc6650d6c --- /dev/null +++ b/cpp/src/ast/expressions.cpp @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace cudf { +namespace ast { + +operation::operation(ast_operator op, expression const& input) : op(op), operands({input}) +{ + if (cudf::ast::detail::ast_operator_arity(op) != 1) { + CUDF_FAIL("The provided operator is not a unary operator."); + } +} + +operation::operation(ast_operator op, expression const& left, expression const& right) + : op(op), operands({left, right}) +{ + if (cudf::ast::detail::ast_operator_arity(op) != 2) { + CUDF_FAIL("The provided operator is not a binary operator."); + } +} + +cudf::size_type literal::accept(detail::expression_parser& visitor) const +{ + return visitor.visit(*this); +} +cudf::size_type column_reference::accept(detail::expression_parser& visitor) const +{ + return visitor.visit(*this); +} +cudf::size_type operation::accept(detail::expression_parser& visitor) const +{ + return visitor.visit(*this); +} + +} // namespace ast + +} // namespace cudf diff --git a/cpp/src/join/conditional_join.cu b/cpp/src/join/conditional_join.cu index ee076d80140..bfabe99aaf9 100644 --- a/cpp/src/join/conditional_join.cu +++ b/cpp/src/join/conditional_join.cu @@ -15,7 +15,7 @@ */ #include -#include +#include #include #include #include @@ -38,7 +38,7 @@ std::pair>, std::unique_ptr>> conditional_join(table_view const& left, table_view const& right, - ast::expression binary_predicate, + ast::expression const& binary_predicate, null_equality compare_nulls, join_kind join_type, std::optional output_size, @@ -171,7 +171,7 @@ conditional_join(table_view const& left, std::size_t compute_conditional_join_output_size(table_view const& left, table_view const& right, - ast::expression binary_predicate, + ast::expression const& binary_predicate, null_equality compare_nulls, join_kind join_type, rmm::cuda_stream_view stream, @@ -248,9 +248,9 @@ std::size_t compute_conditional_join_output_size(table_view const& left, std::pair>, std::unique_ptr>> 
-conditional_inner_join(table_view left, - table_view right, - ast::expression binary_predicate, +conditional_inner_join(table_view const& left, + table_view const& right, + ast::expression const& binary_predicate, null_equality compare_nulls, std::optional output_size, rmm::mr::device_memory_resource* mr) @@ -268,9 +268,9 @@ conditional_inner_join(table_view left, std::pair>, std::unique_ptr>> -conditional_left_join(table_view left, - table_view right, - ast::expression binary_predicate, +conditional_left_join(table_view const& left, + table_view const& right, + ast::expression const& binary_predicate, null_equality compare_nulls, std::optional output_size, rmm::mr::device_memory_resource* mr) @@ -288,9 +288,9 @@ conditional_left_join(table_view left, std::pair>, std::unique_ptr>> -conditional_full_join(table_view left, - table_view right, - ast::expression binary_predicate, +conditional_full_join(table_view const& left, + table_view const& right, + ast::expression const& binary_predicate, null_equality compare_nulls, rmm::mr::device_memory_resource* mr) { @@ -306,9 +306,9 @@ conditional_full_join(table_view left, } std::unique_ptr> conditional_left_semi_join( - table_view left, - table_view right, - ast::expression binary_predicate, + table_view const& left, + table_view const& right, + ast::expression const& binary_predicate, null_equality compare_nulls, std::optional output_size, rmm::mr::device_memory_resource* mr) @@ -326,9 +326,9 @@ std::unique_ptr> conditional_left_semi_join( } std::unique_ptr> conditional_left_anti_join( - table_view left, - table_view right, - ast::expression binary_predicate, + table_view const& left, + table_view const& right, + ast::expression const& binary_predicate, null_equality compare_nulls, std::optional output_size, rmm::mr::device_memory_resource* mr) @@ -345,9 +345,9 @@ std::unique_ptr> conditional_left_anti_join( .first); } -std::size_t conditional_inner_join_size(table_view left, - table_view right, - ast::expression 
binary_predicate, +std::size_t conditional_inner_join_size(table_view const& left, + table_view const& right, + ast::expression const& binary_predicate, null_equality compare_nulls, rmm::mr::device_memory_resource* mr) { @@ -361,9 +361,9 @@ std::size_t conditional_inner_join_size(table_view left, mr); } -std::size_t conditional_left_join_size(table_view left, - table_view right, - ast::expression binary_predicate, +std::size_t conditional_left_join_size(table_view const& left, + table_view const& right, + ast::expression const& binary_predicate, null_equality compare_nulls, rmm::mr::device_memory_resource* mr) { @@ -377,9 +377,9 @@ std::size_t conditional_left_join_size(table_view left, mr); } -std::size_t conditional_left_semi_join_size(table_view left, - table_view right, - ast::expression binary_predicate, +std::size_t conditional_left_semi_join_size(table_view const& left, + table_view const& right, + ast::expression const& binary_predicate, null_equality compare_nulls, rmm::mr::device_memory_resource* mr) { @@ -393,9 +393,9 @@ std::size_t conditional_left_semi_join_size(table_view left, mr)); } -std::size_t conditional_left_anti_join_size(table_view left, - table_view right, - ast::expression binary_predicate, +std::size_t conditional_left_anti_join_size(table_view const& left, + table_view const& right, + ast::expression const& binary_predicate, null_equality compare_nulls, rmm::mr::device_memory_resource* mr) { diff --git a/cpp/src/join/conditional_join.hpp b/cpp/src/join/conditional_join.hpp index b5b49815381..5a3fe887838 100644 --- a/cpp/src/join/conditional_join.hpp +++ b/cpp/src/join/conditional_join.hpp @@ -17,7 +17,7 @@ #include "join_common_utils.hpp" -#include +#include #include #include @@ -45,7 +45,7 @@ std::pair>, std::unique_ptr>> conditional_join(table_view const& left, table_view const& right, - ast::expression binary_predicate, + ast::expression const& binary_predicate, null_equality compare_nulls, join_kind JoinKind, std::optional output_size 
= {}, @@ -68,7 +68,7 @@ conditional_join(table_view const& left, std::size_t compute_conditional_join_output_size( table_view const& left, table_view const& right, - ast::expression binary_predicate, + ast::expression const& binary_predicate, null_equality compare_nulls, join_kind JoinKind, rmm::cuda_stream_view stream = rmm::cuda_stream_default, diff --git a/cpp/src/transform/compute_column.cu b/cpp/src/transform/compute_column.cu index cd8196e555c..1466ee9ad27 100644 --- a/cpp/src/transform/compute_column.cu +++ b/cpp/src/transform/compute_column.cu @@ -16,7 +16,7 @@ #include #include -#include +#include #include #include #include @@ -78,7 +78,7 @@ __launch_bounds__(max_block_size) __global__ } } -std::unique_ptr compute_column(table_view const table, +std::unique_ptr compute_column(table_view const& table, ast::expression const& expr, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) @@ -135,7 +135,7 @@ std::unique_ptr compute_column(table_view const table, } // namespace detail -std::unique_ptr compute_column(table_view const table, +std::unique_ptr compute_column(table_view const& table, ast::expression const& expr, rmm::mr::device_memory_resource* mr) { diff --git a/cpp/tests/ast/transform_tests.cpp b/cpp/tests/ast/transform_tests.cpp index 19797d0ce2e..de6c9d486ec 100644 --- a/cpp/tests/ast/transform_tests.cpp +++ b/cpp/tests/ast/transform_tests.cpp @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include +#include #include #include #include @@ -47,6 +47,35 @@ constexpr cudf::test::debug_output_level verbosity{cudf::test::debug_output_leve struct TransformTest : public cudf::test::BaseFixture { }; +TEST_F(TransformTest, ColumnReference) +{ + auto c_0 = column_wrapper{3, 20, 1, 50}; + auto c_1 = column_wrapper{10, 7, 20, 0}; + auto table = cudf::table_view{{c_0, c_1}}; + + auto col_ref_0 = cudf::ast::column_reference(0); + + auto const& expected = c_0; + auto result = cudf::compute_column(table, col_ref_0); + + cudf::test::expect_columns_equal(expected, result->view(), verbosity); +} + +TEST_F(TransformTest, Literal) +{ + auto c_0 = column_wrapper{3, 20, 1, 50}; + auto c_1 = column_wrapper{10, 7, 20, 0}; + auto table = cudf::table_view{{c_0, c_1}}; + + auto literal_value = cudf::numeric_scalar(42); + auto literal = cudf::ast::literal(literal_value); + + auto expected = column_wrapper{42, 42, 42, 42}; + auto result = cudf::compute_column(table, literal); + + cudf::test::expect_columns_equal(expected, result->view(), verbosity); +} + TEST_F(TransformTest, BasicAddition) { auto c_0 = column_wrapper{3, 20, 1, 50}; @@ -55,7 +84,7 @@ TEST_F(TransformTest, BasicAddition) auto col_ref_0 = cudf::ast::column_reference(0); auto col_ref_1 = cudf::ast::column_reference(1); - auto expression = cudf::ast::expression(cudf::ast::ast_operator::ADD, col_ref_0, col_ref_1); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::ADD, col_ref_0, col_ref_1); auto expected = column_wrapper{13, 27, 21, 50}; auto result = cudf::compute_column(table, expression); @@ -70,7 +99,7 @@ TEST_F(TransformTest, BasicAdditionLarge) auto table = cudf::table_view{{col, col}}; auto col_ref = cudf::ast::column_reference(0); - auto expression = cudf::ast::expression(cudf::ast::ast_operator::ADD, col_ref, col_ref); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::ADD, col_ref, col_ref); auto b = cudf::detail::make_counting_transform_iterator(0, [](auto i) { return i * 
2; }); auto expected = column_wrapper(b, b + 2000); @@ -87,7 +116,7 @@ TEST_F(TransformTest, LessComparator) auto col_ref_0 = cudf::ast::column_reference(0); auto col_ref_1 = cudf::ast::column_reference(1); - auto expression = cudf::ast::expression(cudf::ast::ast_operator::LESS, col_ref_0, col_ref_1); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::LESS, col_ref_0, col_ref_1); auto expected = column_wrapper{true, false, true, false}; auto result = cudf::compute_column(table, expression); @@ -105,7 +134,7 @@ TEST_F(TransformTest, LessComparatorLarge) auto col_ref_0 = cudf::ast::column_reference(0); auto col_ref_1 = cudf::ast::column_reference(1); - auto expression = cudf::ast::expression(cudf::ast::ast_operator::LESS, col_ref_0, col_ref_1); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::LESS, col_ref_0, col_ref_1); auto c = cudf::detail::make_counting_transform_iterator(0, [](auto i) { return i < 500; }); auto expected = column_wrapper(c, c + 2000); @@ -126,12 +155,12 @@ TEST_F(TransformTest, MultiLevelTreeArithmetic) auto col_ref_2 = cudf::ast::column_reference(2); auto expression_left_subtree = - cudf::ast::expression(cudf::ast::ast_operator::ADD, col_ref_0, col_ref_1); + cudf::ast::operation(cudf::ast::ast_operator::ADD, col_ref_0, col_ref_1); auto expression_right_subtree = - cudf::ast::expression(cudf::ast::ast_operator::SUB, col_ref_2, col_ref_0); + cudf::ast::operation(cudf::ast::ast_operator::SUB, col_ref_2, col_ref_0); - auto expression_tree = cudf::ast::expression( + auto expression_tree = cudf::ast::operation( cudf::ast::ast_operator::ADD, expression_left_subtree, expression_right_subtree); auto result = cudf::compute_column(table, expression_tree); @@ -142,8 +171,6 @@ TEST_F(TransformTest, MultiLevelTreeArithmetic) TEST_F(TransformTest, MultiLevelTreeArithmeticLarge) { - using namespace cudf::ast; - auto a = thrust::make_counting_iterator(0); auto b = cudf::detail::make_counting_transform_iterator(0, [](auto i) { 
return i + 1; }); auto c = cudf::detail::make_counting_transform_iterator(0, [](auto i) { return i * 2; }); @@ -152,13 +179,15 @@ TEST_F(TransformTest, MultiLevelTreeArithmeticLarge) auto c_2 = column_wrapper(c, c + 2000); auto table = cudf::table_view{{c_0, c_1, c_2}}; - auto col_ref_0 = column_reference(0); - auto col_ref_1 = column_reference(1); - auto col_ref_2 = column_reference(2); + auto col_ref_0 = cudf::ast::column_reference(0); + auto col_ref_1 = cudf::ast::column_reference(1); + auto col_ref_2 = cudf::ast::column_reference(2); - auto expr_left_subtree = expression(cudf::ast::ast_operator::MUL, col_ref_0, col_ref_1); - auto expr_right_subtree = expression(cudf::ast::ast_operator::ADD, col_ref_2, col_ref_0); - auto expr_tree = expression(ast_operator::SUB, expr_left_subtree, expr_right_subtree); + auto expr_left_subtree = cudf::ast::operation(cudf::ast::ast_operator::MUL, col_ref_0, col_ref_1); + auto expr_right_subtree = + cudf::ast::operation(cudf::ast::ast_operator::ADD, col_ref_2, col_ref_0); + auto expr_tree = + cudf::ast::operation(cudf::ast::ast_operator::SUB, expr_left_subtree, expr_right_subtree); auto result = cudf::compute_column(table, expr_tree); auto calc = [](auto i) { return (i * (i + 1)) - (i + (i * 2)); }; @@ -180,10 +209,10 @@ TEST_F(TransformTest, ImbalancedTreeArithmetic) auto col_ref_2 = cudf::ast::column_reference(2); auto expression_right_subtree = - cudf::ast::expression(cudf::ast::ast_operator::MUL, col_ref_0, col_ref_1); + cudf::ast::operation(cudf::ast::ast_operator::MUL, col_ref_0, col_ref_1); auto expression_tree = - cudf::ast::expression(cudf::ast::ast_operator::SUB, col_ref_2, expression_right_subtree); + cudf::ast::operation(cudf::ast::ast_operator::SUB, col_ref_2, expression_right_subtree); auto result = cudf::compute_column(table, expression_tree); auto expected = @@ -204,12 +233,12 @@ TEST_F(TransformTest, MultiLevelTreeComparator) auto col_ref_2 = cudf::ast::column_reference(2); auto expression_left_subtree = - 
cudf::ast::expression(cudf::ast::ast_operator::GREATER_EQUAL, col_ref_0, col_ref_1); + cudf::ast::operation(cudf::ast::ast_operator::GREATER_EQUAL, col_ref_0, col_ref_1); auto expression_right_subtree = - cudf::ast::expression(cudf::ast::ast_operator::GREATER, col_ref_2, col_ref_0); + cudf::ast::operation(cudf::ast::ast_operator::GREATER, col_ref_2, col_ref_0); - auto expression_tree = cudf::ast::expression( + auto expression_tree = cudf::ast::operation( cudf::ast::ast_operator::LOGICAL_AND, expression_left_subtree, expression_right_subtree); auto result = cudf::compute_column(table, expression_tree); @@ -228,9 +257,9 @@ TEST_F(TransformTest, MultiTypeOperationFailure) auto col_ref_1 = cudf::ast::column_reference(1); auto expression_0_plus_1 = - cudf::ast::expression(cudf::ast::ast_operator::ADD, col_ref_0, col_ref_1); + cudf::ast::operation(cudf::ast::ast_operator::ADD, col_ref_0, col_ref_1); auto expression_1_plus_0 = - cudf::ast::expression(cudf::ast::ast_operator::ADD, col_ref_1, col_ref_0); + cudf::ast::operation(cudf::ast::ast_operator::ADD, col_ref_1, col_ref_0); // Operations on different types are not allowed EXPECT_THROW(cudf::compute_column(table, expression_0_plus_1), cudf::logic_error); @@ -246,7 +275,7 @@ TEST_F(TransformTest, LiteralComparison) auto literal_value = cudf::numeric_scalar(41); auto literal = cudf::ast::literal(literal_value); - auto expression = cudf::ast::expression(cudf::ast::ast_operator::GREATER, col_ref_0, literal); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::GREATER, col_ref_0, literal); auto result = cudf::compute_column(table, expression); auto expected = column_wrapper{false, false, false, true}; @@ -261,7 +290,7 @@ TEST_F(TransformTest, UnaryNot) auto col_ref_0 = cudf::ast::column_reference(0); - auto expression = cudf::ast::expression(cudf::ast::ast_operator::NOT, col_ref_0); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::NOT, col_ref_0); auto result = cudf::compute_column(table, 
expression); auto expected = column_wrapper{false, true, false, false}; @@ -277,17 +306,17 @@ TEST_F(TransformTest, UnaryTrigonometry) auto col_ref_0 = cudf::ast::column_reference(0); auto expected_sin = column_wrapper{0.0, std::sqrt(2) / 2, std::sqrt(3.0) / 2.0}; - auto expression_sin = cudf::ast::expression(cudf::ast::ast_operator::SIN, col_ref_0); + auto expression_sin = cudf::ast::operation(cudf::ast::ast_operator::SIN, col_ref_0); auto result_sin = cudf::compute_column(table, expression_sin); cudf::test::expect_columns_equivalent(expected_sin, result_sin->view(), verbosity); auto expected_cos = column_wrapper{1.0, std::sqrt(2) / 2, 0.5}; - auto expression_cos = cudf::ast::expression(cudf::ast::ast_operator::COS, col_ref_0); + auto expression_cos = cudf::ast::operation(cudf::ast::ast_operator::COS, col_ref_0); auto result_cos = cudf::compute_column(table, expression_cos); cudf::test::expect_columns_equivalent(expected_cos, result_cos->view(), verbosity); auto expected_tan = column_wrapper{0.0, 1.0, std::sqrt(3.0)}; - auto expression_tan = cudf::ast::expression(cudf::ast::ast_operator::TAN, col_ref_0); + auto expression_tan = cudf::ast::operation(cudf::ast::ast_operator::TAN, col_ref_0); auto result_tan = cudf::compute_column(table, expression_tan); cudf::test::expect_columns_equivalent(expected_tan, result_tan->view(), verbosity); } @@ -295,8 +324,8 @@ TEST_F(TransformTest, UnaryTrigonometry) TEST_F(TransformTest, ArityCheckFailure) { auto col_ref_0 = cudf::ast::column_reference(0); - EXPECT_THROW(cudf::ast::expression(cudf::ast::ast_operator::ADD, col_ref_0), cudf::logic_error); - EXPECT_THROW(cudf::ast::expression(cudf::ast::ast_operator::ABS, col_ref_0, col_ref_0), + EXPECT_THROW(cudf::ast::operation(cudf::ast::ast_operator::ADD, col_ref_0), cudf::logic_error); + EXPECT_THROW(cudf::ast::operation(cudf::ast::ast_operator::ABS, col_ref_0, col_ref_0), cudf::logic_error); } @@ -308,7 +337,7 @@ TEST_F(TransformTest, StringComparison) auto col_ref_0 = 
cudf::ast::column_reference(0); auto col_ref_1 = cudf::ast::column_reference(1); - auto expression = cudf::ast::expression(cudf::ast::ast_operator::LESS, col_ref_0, col_ref_1); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::LESS, col_ref_0, col_ref_1); auto expected = column_wrapper{true, false, true, false}; auto result = cudf::compute_column(table, expression); @@ -322,7 +351,7 @@ TEST_F(TransformTest, CopyColumn) auto table = cudf::table_view{{c_0}}; auto col_ref_0 = cudf::ast::column_reference(0); - auto expression = cudf::ast::expression(cudf::ast::ast_operator::IDENTITY, col_ref_0); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::IDENTITY, col_ref_0); auto result = cudf::compute_column(table, expression); auto expected = column_wrapper{3, 0, 1, 50}; @@ -338,7 +367,7 @@ TEST_F(TransformTest, CopyLiteral) auto literal_value = cudf::numeric_scalar(-123); auto literal = cudf::ast::literal(literal_value); - auto expression = cudf::ast::expression(cudf::ast::ast_operator::IDENTITY, literal); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::IDENTITY, literal); auto result = cudf::compute_column(table, expression); auto expected = column_wrapper{-123, -123, -123, -123}; @@ -355,7 +384,7 @@ TEST_F(TransformTest, TrueDiv) auto literal_value = cudf::numeric_scalar(2); auto literal = cudf::ast::literal(literal_value); - auto expression = cudf::ast::expression(cudf::ast::ast_operator::TRUE_DIV, col_ref_0, literal); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::TRUE_DIV, col_ref_0, literal); auto result = cudf::compute_column(table, expression); auto expected = column_wrapper{1.5, 0.0, 0.5, 25.0}; @@ -372,7 +401,7 @@ TEST_F(TransformTest, FloorDiv) auto literal_value = cudf::numeric_scalar(2.0); auto literal = cudf::ast::literal(literal_value); - auto expression = cudf::ast::expression(cudf::ast::ast_operator::FLOOR_DIV, col_ref_0, literal); + auto expression = 
cudf::ast::operation(cudf::ast::ast_operator::FLOOR_DIV, col_ref_0, literal); auto result = cudf::compute_column(table, expression); auto expected = column_wrapper{1.0, 0.0, 0.0, 25.0}; @@ -389,7 +418,7 @@ TEST_F(TransformTest, Mod) auto literal_value = cudf::numeric_scalar(2.0); auto literal = cudf::ast::literal(literal_value); - auto expression = cudf::ast::expression(cudf::ast::ast_operator::MOD, col_ref_0, literal); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::MOD, col_ref_0, literal); auto result = cudf::compute_column(table, expression); auto expected = column_wrapper{1.0, 0.0, -1.0, 0.0}; @@ -406,7 +435,7 @@ TEST_F(TransformTest, PyMod) auto literal_value = cudf::numeric_scalar(2.0); auto literal = cudf::ast::literal(literal_value); - auto expression = cudf::ast::expression(cudf::ast::ast_operator::PYMOD, col_ref_0, literal); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::PYMOD, col_ref_0, literal); auto result = cudf::compute_column(table, expression); auto expected = column_wrapper{1.0, 0.0, 1.0, 0.0}; @@ -422,7 +451,7 @@ TEST_F(TransformTest, BasicAdditionNulls) auto col_ref_0 = cudf::ast::column_reference(0); auto col_ref_1 = cudf::ast::column_reference(1); - auto expression = cudf::ast::expression(cudf::ast::ast_operator::ADD, col_ref_0, col_ref_1); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::ADD, col_ref_0, col_ref_1); auto expected = column_wrapper{{0, 0, 0, 50}, {0, 0, 0, 1}}; auto result = cudf::compute_column(table, expression); @@ -447,7 +476,7 @@ TEST_F(TransformTest, BasicAdditionLargeNulls) auto table = cudf::table_view{{col}}; auto col_ref = cudf::ast::column_reference(0); - auto expression = cudf::ast::expression(cudf::ast::ast_operator::ADD, col_ref, col_ref); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::ADD, col_ref, col_ref); auto b = cudf::detail::make_counting_transform_iterator(0, [](auto i) { return i * 2; }); auto expected = column_wrapper(b, b + N, 
validities.begin()); diff --git a/cpp/tests/join/conditional_join_tests.cu b/cpp/tests/join/conditional_join_tests.cu index e16e1ec7de8..8018d613e05 100644 --- a/cpp/tests/join/conditional_join_tests.cu +++ b/cpp/tests/join/conditional_join_tests.cu @@ -14,8 +14,7 @@ * limitations under the License. */ -#include -#include +#include #include #include #include @@ -50,7 +49,7 @@ const auto col_ref_right_1 = cudf::ast::column_reference(1, cudf::ast::table_ref // Common expressions. auto left_zero_eq_right_zero = - cudf::ast::expression(cudf::ast::ast_operator::EQUAL, col_ref_left_0, col_ref_right_0); + cudf::ast::operation(cudf::ast::ast_operator::EQUAL, col_ref_left_0, col_ref_right_0); } // namespace /** @@ -147,7 +146,7 @@ struct ConditionalJoinPairReturnTest : public ConditionalJoinTest { */ void test(std::vector> left_data, std::vector> right_data, - cudf::ast::expression predicate, + cudf::ast::operation predicate, std::vector> expected_outputs) { // Note that we need to maintain the column wrappers otherwise the @@ -174,7 +173,7 @@ struct ConditionalJoinPairReturnTest : public ConditionalJoinTest { void test_nulls(std::vector, std::vector>> left_data, std::vector, std::vector>> right_data, - cudf::ast::expression predicate, + cudf::ast::operation predicate, std::vector> expected_outputs) { // Note that we need to maintain the column wrappers otherwise the @@ -252,7 +251,7 @@ struct ConditionalJoinPairReturnTest : public ConditionalJoinTest { */ virtual std::pair>, std::unique_ptr>> - join(cudf::table_view left, cudf::table_view right, cudf::ast::expression predicate) = 0; + join(cudf::table_view left, cudf::table_view right, cudf::ast::operation predicate) = 0; /** * This method must be implemented by subclasses for specific types of joins. 
@@ -261,7 +260,7 @@ struct ConditionalJoinPairReturnTest : public ConditionalJoinTest { */ virtual std::size_t join_size(cudf::table_view left, cudf::table_view right, - cudf::ast::expression predicate) = 0; + cudf::ast::operation predicate) = 0; /** * This method must be implemented by subclasses for specific types of joins. @@ -280,14 +279,14 @@ template struct ConditionalInnerJoinTest : public ConditionalJoinPairReturnTest { std::pair>, std::unique_ptr>> - join(cudf::table_view left, cudf::table_view right, cudf::ast::expression predicate) override + join(cudf::table_view left, cudf::table_view right, cudf::ast::operation predicate) override { return cudf::conditional_inner_join(left, right, predicate); } std::size_t join_size(cudf::table_view left, cudf::table_view right, - cudf::ast::expression predicate) override + cudf::ast::operation predicate) override { return cudf::conditional_inner_join_size(left, right, predicate); } @@ -336,7 +335,7 @@ TYPED_TEST(ConditionalInnerJoinTest, TestTwoColumnThreeRowSomeEqual) TYPED_TEST(ConditionalInnerJoinTest, TestNotComparison) { auto col_ref_0 = cudf::ast::column_reference(0); - auto expression = cudf::ast::expression(cudf::ast::ast_operator::NOT, col_ref_0); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::NOT, col_ref_0); this->test({{0, 1, 2}}, {{3, 4, 5}}, expression, {{0, 0}, {0, 1}, {0, 2}}); }; @@ -345,7 +344,7 @@ TYPED_TEST(ConditionalInnerJoinTest, TestGreaterComparison) { auto col_ref_0 = cudf::ast::column_reference(0); auto col_ref_1 = cudf::ast::column_reference(0, cudf::ast::table_reference::RIGHT); - auto expression = cudf::ast::expression(cudf::ast::ast_operator::GREATER, col_ref_0, col_ref_1); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::GREATER, col_ref_0, col_ref_1); this->test({{0, 1, 2}}, {{1, 0, 0}}, expression, {{1, 1}, {1, 2}, {2, 0}, {2, 1}, {2, 2}}); }; @@ -354,7 +353,7 @@ TYPED_TEST(ConditionalInnerJoinTest, TestGreaterTwoColumnComparison) { auto col_ref_0 
= cudf::ast::column_reference(0); auto col_ref_1 = cudf::ast::column_reference(1, cudf::ast::table_reference::RIGHT); - auto expression = cudf::ast::expression(cudf::ast::ast_operator::GREATER, col_ref_0, col_ref_1); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::GREATER, col_ref_0, col_ref_1); this->test({{0, 1, 2}, {0, 0, 0}}, {{0, 0, 0}, {1, 0, 0}}, @@ -366,7 +365,7 @@ TYPED_TEST(ConditionalInnerJoinTest, TestGreaterDifferentNumberColumnComparison) { auto col_ref_0 = cudf::ast::column_reference(0); auto col_ref_1 = cudf::ast::column_reference(1, cudf::ast::table_reference::RIGHT); - auto expression = cudf::ast::expression(cudf::ast::ast_operator::GREATER, col_ref_0, col_ref_1); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::GREATER, col_ref_0, col_ref_1); this->test( {{0, 1, 2}}, {{0, 0, 0}, {1, 0, 0}}, expression, {{1, 1}, {1, 2}, {2, 0}, {2, 1}, {2, 2}}); @@ -376,7 +375,7 @@ TYPED_TEST(ConditionalInnerJoinTest, TestGreaterDifferentNumberColumnDifferentSi { auto col_ref_0 = cudf::ast::column_reference(0); auto col_ref_1 = cudf::ast::column_reference(1, cudf::ast::table_reference::RIGHT); - auto expression = cudf::ast::expression(cudf::ast::ast_operator::GREATER, col_ref_0, col_ref_1); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::GREATER, col_ref_0, col_ref_1); this->test({{0, 1}}, {{0, 0, 0}, {1, 0, 0}}, expression, {{1, 1}, {1, 2}}); }; @@ -387,14 +386,14 @@ TYPED_TEST(ConditionalInnerJoinTest, TestComplexConditionMultipleColumns) auto col_ref_0 = cudf::ast::column_reference(0, cudf::ast::table_reference::LEFT); auto scalar_1 = cudf::numeric_scalar(1); auto literal_1 = cudf::ast::literal(scalar_1); - auto left_0_equal_1 = cudf::ast::expression(cudf::ast::ast_operator::EQUAL, col_ref_0, literal_1); + auto left_0_equal_1 = cudf::ast::operation(cudf::ast::ast_operator::EQUAL, col_ref_0, literal_1); auto col_ref_1 = cudf::ast::column_reference(1, cudf::ast::table_reference::RIGHT); auto comparison_filter = 
- cudf::ast::expression(cudf::ast::ast_operator::LESS, col_ref_1, col_ref_0); + cudf::ast::operation(cudf::ast::ast_operator::LESS, col_ref_1, col_ref_0); auto expression = - cudf::ast::expression(cudf::ast::ast_operator::LOGICAL_AND, left_0_equal_1, comparison_filter); + cudf::ast::operation(cudf::ast::ast_operator::LOGICAL_AND, left_0_equal_1, comparison_filter); this->test({{0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}}, {{0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2}, @@ -408,9 +407,9 @@ TYPED_TEST(ConditionalInnerJoinTest, TestSymmetry) { auto col_ref_0 = cudf::ast::column_reference(0); auto col_ref_1 = cudf::ast::column_reference(0, cudf::ast::table_reference::RIGHT); - auto expression = cudf::ast::expression(cudf::ast::ast_operator::GREATER, col_ref_1, col_ref_0); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::GREATER, col_ref_1, col_ref_0); auto expression_reverse = - cudf::ast::expression(cudf::ast::ast_operator::LESS, col_ref_0, col_ref_1); + cudf::ast::operation(cudf::ast::ast_operator::LESS, col_ref_0, col_ref_1); this->test( {{0, 1, 2}}, {{1, 2, 3}}, expression, {{0, 0}, {0, 1}, {0, 2}, {1, 1}, {1, 2}, {2, 2}}); @@ -462,14 +461,14 @@ template struct ConditionalLeftJoinTest : public ConditionalJoinPairReturnTest { std::pair>, std::unique_ptr>> - join(cudf::table_view left, cudf::table_view right, cudf::ast::expression predicate) override + join(cudf::table_view left, cudf::table_view right, cudf::ast::operation predicate) override { return cudf::conditional_left_join(left, right, predicate); } std::size_t join_size(cudf::table_view left, cudf::table_view right, - cudf::ast::expression predicate) override + cudf::ast::operation predicate) override { return cudf::conditional_left_join_size(left, right, predicate); } @@ -525,14 +524,14 @@ template struct ConditionalFullJoinTest : public ConditionalJoinPairReturnTest { std::pair>, std::unique_ptr>> - join(cudf::table_view left, cudf::table_view right, 
cudf::ast::expression predicate) override + join(cudf::table_view left, cudf::table_view right, cudf::ast::operation predicate) override { return cudf::conditional_full_join(left, right, predicate); } std::size_t join_size(cudf::table_view left, cudf::table_view right, - cudf::ast::expression predicate) override + cudf::ast::operation predicate) override { // Full joins don't actually support size calculations, but to support a // uniform testing framework we just calculate it from the result of doing @@ -610,7 +609,7 @@ struct ConditionalJoinSingleReturnTest : public ConditionalJoinTest { */ void test(std::vector> left_data, std::vector> right_data, - cudf::ast::expression predicate, + cudf::ast::operation predicate, std::vector expected_outputs) { auto [left_wrappers, right_wrappers, left_columns, right_columns, left, right] = @@ -661,7 +660,7 @@ struct ConditionalJoinSingleReturnTest : public ConditionalJoinTest { * conditional join API. */ virtual std::unique_ptr> join( - cudf::table_view left, cudf::table_view right, cudf::ast::expression predicate) = 0; + cudf::table_view left, cudf::table_view right, cudf::ast::operation predicate) = 0; /** * This method must be implemented by subclasses for specific types of joins. @@ -670,7 +669,7 @@ struct ConditionalJoinSingleReturnTest : public ConditionalJoinTest { */ virtual std::size_t join_size(cudf::table_view left, cudf::table_view right, - cudf::ast::expression predicate) = 0; + cudf::ast::operation predicate) = 0; /** * This method must be implemented by subclasses for specific types of joins. 
@@ -687,14 +686,14 @@ struct ConditionalJoinSingleReturnTest : public ConditionalJoinTest { template struct ConditionalLeftSemiJoinTest : public ConditionalJoinSingleReturnTest { std::unique_ptr> join( - cudf::table_view left, cudf::table_view right, cudf::ast::expression predicate) override + cudf::table_view left, cudf::table_view right, cudf::ast::operation predicate) override { return cudf::conditional_left_semi_join(left, right, predicate); } std::size_t join_size(cudf::table_view left, cudf::table_view right, - cudf::ast::expression predicate) override + cudf::ast::operation predicate) override { return cudf::conditional_left_semi_join_size(left, right, predicate); } @@ -745,14 +744,14 @@ TYPED_TEST(ConditionalLeftSemiJoinTest, TestCompareRandomToHash) template struct ConditionalLeftAntiJoinTest : public ConditionalJoinSingleReturnTest { std::unique_ptr> join( - cudf::table_view left, cudf::table_view right, cudf::ast::expression predicate) override + cudf::table_view left, cudf::table_view right, cudf::ast::operation predicate) override { return cudf::conditional_left_anti_join(left, right, predicate); } std::size_t join_size(cudf::table_view left, cudf::table_view right, - cudf::ast::expression predicate) override + cudf::ast::operation predicate) override { return cudf::conditional_left_anti_join_size(left, right, predicate); } diff --git a/java/src/main/java/ai/rapids/cudf/ast/AstNode.java b/java/src/main/java/ai/rapids/cudf/ast/AstExpression.java similarity index 82% rename from java/src/main/java/ai/rapids/cudf/ast/AstNode.java rename to java/src/main/java/ai/rapids/cudf/ast/AstExpression.java index 8160462de98..5ac15f714f0 100644 --- a/java/src/main/java/ai/rapids/cudf/ast/AstNode.java +++ b/java/src/main/java/ai/rapids/cudf/ast/AstExpression.java @@ -17,14 +17,15 @@ package ai.rapids.cudf.ast; import java.nio.ByteBuffer; +import java.nio.ByteOrder; /** Base class of every node in an AST */ -public abstract class AstNode { +public abstract class 
AstExpression { /** * Enumeration for the types of AST nodes that can appear in a serialized AST. * NOTE: This must be kept in sync with the `jni_serialized_node_type` in CompiledExpression.cpp! */ - protected enum NodeType { + protected enum ExpressionType { VALID_LITERAL(0), NULL_LITERAL(1), COLUMN_REFERENCE(2), @@ -33,7 +34,7 @@ protected enum NodeType { private final byte nativeId; - NodeType(int nativeId) { + ExpressionType(int nativeId) { this.nativeId = (byte) nativeId; assert this.nativeId == nativeId; } @@ -49,6 +50,14 @@ void serialize(ByteBuffer bb) { } } + public CompiledExpression compile() { + int size = getSerializedSize(); + ByteBuffer bb = ByteBuffer.allocate(size); + bb.order(ByteOrder.nativeOrder()); + serialize(bb); + return new CompiledExpression(bb.array()); + } + /** Get the size in bytes of the serialized form of this node and all child nodes */ abstract int getSerializedSize(); diff --git a/java/src/main/java/ai/rapids/cudf/ast/BinaryExpression.java b/java/src/main/java/ai/rapids/cudf/ast/BinaryOperation.java similarity index 72% rename from java/src/main/java/ai/rapids/cudf/ast/BinaryExpression.java rename to java/src/main/java/ai/rapids/cudf/ast/BinaryOperation.java index ed4f95b01e1..c39c1c3a1c5 100644 --- a/java/src/main/java/ai/rapids/cudf/ast/BinaryExpression.java +++ b/java/src/main/java/ai/rapids/cudf/ast/BinaryOperation.java @@ -18,13 +18,13 @@ import java.nio.ByteBuffer; -/** A binary expression consisting of an operator and two operands. */ -public class BinaryExpression extends Expression { +/** A binary operation consisting of an operator and two operands. 
*/ +public class BinaryOperation extends AstExpression { private final BinaryOperator op; - private final AstNode leftInput; - private final AstNode rightInput; + private final AstExpression leftInput; + private final AstExpression rightInput; - public BinaryExpression(BinaryOperator op, AstNode leftInput, AstNode rightInput) { + public BinaryOperation(BinaryOperator op, AstExpression leftInput, AstExpression rightInput) { this.op = op; this.leftInput = leftInput; this.rightInput = rightInput; @@ -32,7 +32,7 @@ public BinaryExpression(BinaryOperator op, AstNode leftInput, AstNode rightInput @Override int getSerializedSize() { - return NodeType.BINARY_EXPRESSION.getSerializedSize() + + return ExpressionType.BINARY_EXPRESSION.getSerializedSize() + op.getSerializedSize() + leftInput.getSerializedSize() + rightInput.getSerializedSize(); @@ -40,7 +40,7 @@ int getSerializedSize() { @Override void serialize(ByteBuffer bb) { - NodeType.BINARY_EXPRESSION.serialize(bb); + ExpressionType.BINARY_EXPRESSION.serialize(bb); op.serialize(bb); leftInput.serialize(bb); rightInput.serialize(bb); diff --git a/java/src/main/java/ai/rapids/cudf/ast/BinaryOperator.java b/java/src/main/java/ai/rapids/cudf/ast/BinaryOperator.java index 12e4d985658..595badb14b6 100644 --- a/java/src/main/java/ai/rapids/cudf/ast/BinaryOperator.java +++ b/java/src/main/java/ai/rapids/cudf/ast/BinaryOperator.java @@ -19,7 +19,7 @@ import java.nio.ByteBuffer; /** - * Enumeration of AST operations that can appear in a binary expression. + * Enumeration of AST operators that can appear in a binary operation. * NOTE: This must be kept in sync with `jni_to_binary_operator` in CompiledExpression.cpp! 
*/ public enum BinaryOperator { diff --git a/java/src/main/java/ai/rapids/cudf/ast/ColumnReference.java b/java/src/main/java/ai/rapids/cudf/ast/ColumnReference.java index 34e4064e23b..4860a088a83 100644 --- a/java/src/main/java/ai/rapids/cudf/ast/ColumnReference.java +++ b/java/src/main/java/ai/rapids/cudf/ast/ColumnReference.java @@ -19,7 +19,7 @@ import java.nio.ByteBuffer; /** A reference to a column in an input table. */ -public final class ColumnReference extends AstNode { +public final class ColumnReference extends AstExpression { private final int columnIndex; private final TableReference tableSource; @@ -37,14 +37,14 @@ public ColumnReference(int columnIndex, TableReference tableSource) { @Override int getSerializedSize() { // node type + table ref + column index - return NodeType.COLUMN_REFERENCE.getSerializedSize() + + return ExpressionType.COLUMN_REFERENCE.getSerializedSize() + tableSource.getSerializedSize() + Integer.BYTES; } @Override void serialize(ByteBuffer bb) { - NodeType.COLUMN_REFERENCE.serialize(bb); + ExpressionType.COLUMN_REFERENCE.serialize(bb); tableSource.serialize(bb); bb.putInt(columnIndex); } diff --git a/java/src/main/java/ai/rapids/cudf/ast/Expression.java b/java/src/main/java/ai/rapids/cudf/ast/Expression.java deleted file mode 100644 index 8d391298cef..00000000000 --- a/java/src/main/java/ai/rapids/cudf/ast/Expression.java +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2021, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package ai.rapids.cudf.ast; - -import java.nio.ByteBuffer; -import java.nio.ByteOrder; - -/** Base class of every AST expression. */ -public abstract class Expression extends AstNode { - public CompiledExpression compile() { - int size = getSerializedSize(); - ByteBuffer bb = ByteBuffer.allocate(size); - bb.order(ByteOrder.nativeOrder()); - serialize(bb); - return new CompiledExpression(bb.array()); - } -} diff --git a/java/src/main/java/ai/rapids/cudf/ast/Literal.java b/java/src/main/java/ai/rapids/cudf/ast/Literal.java index be306cd99c4..b93efce8c94 100644 --- a/java/src/main/java/ai/rapids/cudf/ast/Literal.java +++ b/java/src/main/java/ai/rapids/cudf/ast/Literal.java @@ -22,7 +22,7 @@ import java.nio.ByteOrder; /** A literal value in an AST expression. */ -public final class Literal extends AstNode { +public final class Literal extends AstExpression { private final DType type; private final byte[] serializedValue; @@ -207,8 +207,8 @@ public static Literal ofDurationFromLong(DType type, Long value) { @Override int getSerializedSize() { - NodeType nodeType = serializedValue != null - ? NodeType.VALID_LITERAL : NodeType.NULL_LITERAL; + ExpressionType nodeType = serializedValue != null + ? ExpressionType.VALID_LITERAL : ExpressionType.NULL_LITERAL; int size = nodeType.getSerializedSize() + getDataTypeSerializedSize(); if (serializedValue != null) { size += serializedValue.length; @@ -218,8 +218,8 @@ int getSerializedSize() { @Override void serialize(ByteBuffer bb) { - NodeType nodeType = serializedValue != null - ? NodeType.VALID_LITERAL : NodeType.NULL_LITERAL; + ExpressionType nodeType = serializedValue != null + ? 
ExpressionType.VALID_LITERAL : ExpressionType.NULL_LITERAL; nodeType.serialize(bb); serializeDataType(bb); if (serializedValue != null) { diff --git a/java/src/main/java/ai/rapids/cudf/ast/UnaryExpression.java b/java/src/main/java/ai/rapids/cudf/ast/UnaryOperation.java similarity index 73% rename from java/src/main/java/ai/rapids/cudf/ast/UnaryExpression.java rename to java/src/main/java/ai/rapids/cudf/ast/UnaryOperation.java index fa8e70266ac..03c4c45afd4 100644 --- a/java/src/main/java/ai/rapids/cudf/ast/UnaryExpression.java +++ b/java/src/main/java/ai/rapids/cudf/ast/UnaryOperation.java @@ -18,26 +18,26 @@ import java.nio.ByteBuffer; -/** A unary expression consisting of an operator and an operand. */ -public final class UnaryExpression extends Expression { +/** A unary operation consisting of an operator and an operand. */ +public final class UnaryOperation extends AstExpression { private final UnaryOperator op; - private final AstNode input; + private final AstExpression input; - public UnaryExpression(UnaryOperator op, AstNode input) { + public UnaryOperation(UnaryOperator op, AstExpression input) { this.op = op; this.input = input; } @Override int getSerializedSize() { - return NodeType.UNARY_EXPRESSION.getSerializedSize() + + return ExpressionType.UNARY_EXPRESSION.getSerializedSize() + op.getSerializedSize() + input.getSerializedSize(); } @Override void serialize(ByteBuffer bb) { - NodeType.UNARY_EXPRESSION.serialize(bb); + ExpressionType.UNARY_EXPRESSION.serialize(bb); op.serialize(bb); input.serialize(bb); } diff --git a/java/src/main/java/ai/rapids/cudf/ast/UnaryOperator.java b/java/src/main/java/ai/rapids/cudf/ast/UnaryOperator.java index c3f193d06b4..9ef18dbd75d 100644 --- a/java/src/main/java/ai/rapids/cudf/ast/UnaryOperator.java +++ b/java/src/main/java/ai/rapids/cudf/ast/UnaryOperator.java @@ -19,7 +19,7 @@ import java.nio.ByteBuffer; /** - * Enumeration of AST operations that can appear in a unary expression. 
+ * Enumeration of AST operators that can appear in a unary operation. * NOTE: This must be kept in sync with `jni_to_unary_operator` in CompiledExpression.cpp! */ public enum UnaryOperator { diff --git a/java/src/main/native/src/CompiledExpression.cpp b/java/src/main/native/src/CompiledExpression.cpp index fe57f79c955..470464f35c8 100644 --- a/java/src/main/native/src/CompiledExpression.cpp +++ b/java/src/main/native/src/CompiledExpression.cpp @@ -18,8 +18,7 @@ #include #include -#include -#include +#include #include #include #include @@ -104,15 +103,15 @@ class jni_serialized_ast { }; /** - * Enumeration of the AST node types that can appear in the serialized data. + * Enumeration of the AST expression types that can appear in the serialized data. * NOTE: This must be kept in sync with the NodeType enumeration in AstNode.java! */ -enum class jni_serialized_node_type : int8_t { +enum class jni_serialized_expression_type : int8_t { VALID_LITERAL = 0, NULL_LITERAL = 1, COLUMN_REFERENCE = 2, - UNARY_EXPRESSION = 3, - BINARY_EXPRESSION = 4 + UNARY_OPERATION = 3, + BINARY_OPERATION = 4 }; /** @@ -276,41 +275,42 @@ cudf::ast::column_reference &compile_column_reference(cudf::jni::ast::compiled_e } // forward declaration -cudf::ast::detail::node &compile_node(cudf::jni::ast::compiled_expr &compiled_expr, - jni_serialized_ast &jni_ast); +cudf::ast::expression &compile_expression(cudf::jni::ast::compiled_expr &compiled_expr, + jni_serialized_ast &jni_ast); /** Decode a serialized AST unary expression */ -cudf::ast::expression &compile_unary_expression(cudf::jni::ast::compiled_expr &compiled_expr, - jni_serialized_ast &jni_ast) { +cudf::ast::operation &compile_unary_expression(cudf::jni::ast::compiled_expr &compiled_expr, + jni_serialized_ast &jni_ast) { auto const ast_op = jni_to_unary_operator(jni_ast.read_byte()); - cudf::ast::detail::node &child_node = compile_node(compiled_expr, jni_ast); - return compiled_expr.add_expression(std::make_unique(ast_op, child_node)); + 
cudf::ast::expression &child_expression = compile_expression(compiled_expr, jni_ast); + return compiled_expr.add_operation( + std::make_unique(ast_op, child_expression)); } /** Decode a serialized AST binary expression */ -cudf::ast::expression &compile_binary_expression(cudf::jni::ast::compiled_expr &compiled_expr, - jni_serialized_ast &jni_ast) { +cudf::ast::operation &compile_binary_expression(cudf::jni::ast::compiled_expr &compiled_expr, + jni_serialized_ast &jni_ast) { auto const ast_op = jni_to_binary_operator(jni_ast.read_byte()); - cudf::ast::detail::node &left_child = compile_node(compiled_expr, jni_ast); - cudf::ast::detail::node &right_child = compile_node(compiled_expr, jni_ast); - return compiled_expr.add_expression( - std::make_unique(ast_op, left_child, right_child)); + cudf::ast::expression &left_child = compile_expression(compiled_expr, jni_ast); + cudf::ast::expression &right_child = compile_expression(compiled_expr, jni_ast); + return compiled_expr.add_operation( + std::make_unique(ast_op, left_child, right_child)); } -/** Decode a serialized AST node by reading the node type and dispatching */ -cudf::ast::detail::node &compile_node(cudf::jni::ast::compiled_expr &compiled_expr, - jni_serialized_ast &jni_ast) { - auto const node_type = static_cast(jni_ast.read_byte()); - switch (node_type) { - case jni_serialized_node_type::VALID_LITERAL: +/** Decode a serialized AST expression by reading the expression type and dispatching */ +cudf::ast::expression &compile_expression(cudf::jni::ast::compiled_expr &compiled_expr, + jni_serialized_ast &jni_ast) { + auto const expression_type = static_cast(jni_ast.read_byte()); + switch (expression_type) { + case jni_serialized_expression_type::VALID_LITERAL: return compile_literal(true, compiled_expr, jni_ast); - case jni_serialized_node_type::NULL_LITERAL: + case jni_serialized_expression_type::NULL_LITERAL: return compile_literal(false, compiled_expr, jni_ast); - case jni_serialized_node_type::COLUMN_REFERENCE: 
+ case jni_serialized_expression_type::COLUMN_REFERENCE: return compile_column_reference(compiled_expr, jni_ast); - case jni_serialized_node_type::UNARY_EXPRESSION: + case jni_serialized_expression_type::UNARY_OPERATION: return compile_unary_expression(compiled_expr, jni_ast); - case jni_serialized_node_type::BINARY_EXPRESSION: + case jni_serialized_expression_type::BINARY_OPERATION: return compile_binary_expression(compiled_expr, jni_ast); default: throw std::invalid_argument("data is not a serialized AST expression"); } @@ -319,16 +319,7 @@ cudf::ast::detail::node &compile_node(cudf::jni::ast::compiled_expr &compiled_ex /** Decode a serialized AST into a native libcudf AST and associated resources */ std::unique_ptr compile_serialized_ast(jni_serialized_ast &jni_ast) { auto jni_expr_ptr = std::make_unique(); - auto const node_type = static_cast(jni_ast.read_byte()); - switch (node_type) { - case jni_serialized_node_type::UNARY_EXPRESSION: - (void)compile_unary_expression(*jni_expr_ptr, jni_ast); - break; - case jni_serialized_node_type::BINARY_EXPRESSION: - (void)compile_binary_expression(*jni_expr_ptr, jni_ast); - break; - default: throw std::invalid_argument("data is not a serialized AST expression"); - } + (void)compile_expression(*jni_expr_ptr, jni_ast); if (!jni_ast.at_eof()) { throw std::invalid_argument("Extra bytes at end of serialized AST"); diff --git a/java/src/main/native/src/jni_compiled_expr.hpp b/java/src/main/native/src/jni_compiled_expr.hpp index e42e5a37fba..74010f71011 100644 --- a/java/src/main/native/src/jni_compiled_expr.hpp +++ b/java/src/main/native/src/jni_compiled_expr.hpp @@ -32,12 +32,6 @@ namespace ast { * base AST node type. Then we do not have to track every AST node type separately. 
*/ class compiled_expr { - /** All literal nodes within the expression tree */ - std::vector> literals; - - /** All column reference nodes within the expression tree */ - std::vector> column_refs; - /** All expression nodes within the expression tree */ std::vector> expressions; @@ -47,20 +41,20 @@ class compiled_expr { public: cudf::ast::literal &add_literal(std::unique_ptr literal_ptr, std::unique_ptr scalar_ptr) { - literals.push_back(std::move(literal_ptr)); + expressions.push_back(std::move(literal_ptr)); scalars.push_back(std::move(scalar_ptr)); - return *literals.back(); + return static_cast(*expressions.back()); } cudf::ast::column_reference & add_column_ref(std::unique_ptr ref_ptr) { - column_refs.push_back(std::move(ref_ptr)); - return *column_refs.back(); + expressions.push_back(std::move(ref_ptr)); + return static_cast(*expressions.back()); } - cudf::ast::expression &add_expression(std::unique_ptr expr_ptr) { + cudf::ast::operation &add_operation(std::unique_ptr expr_ptr) { expressions.push_back(std::move(expr_ptr)); - return *expressions.back(); + return static_cast(*expressions.back()); } /** Return the expression node at the top of the tree */ diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 725abd9486d..8e4e3df612b 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -25,7 +25,7 @@ import ai.rapids.cudf.HostColumnVector.StructData; import ai.rapids.cudf.HostColumnVector.StructType; -import ai.rapids.cudf.ast.BinaryExpression; +import ai.rapids.cudf.ast.BinaryOperation; import ai.rapids.cudf.ast.BinaryOperator; import ai.rapids.cudf.ast.ColumnReference; import ai.rapids.cudf.ast.CompiledExpression; @@ -1503,7 +1503,7 @@ void testLeftJoinGatherMapsNulls() { @Test void testConditionalLeftJoinGatherMaps() { final int inv = Integer.MIN_VALUE; - BinaryExpression expr = new BinaryExpression(BinaryOperator.GREATER, + 
BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); @@ -1529,7 +1529,7 @@ void testConditionalLeftJoinGatherMaps() { @Test void testConditionalLeftJoinGatherMapsNulls() { final int inv = Integer.MIN_VALUE; - BinaryExpression expr = new BinaryExpression(BinaryOperator.EQUAL, + BinaryOperation expr = new BinaryOperation(BinaryOperator.EQUAL, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder() @@ -1557,7 +1557,7 @@ void testConditionalLeftJoinGatherMapsNulls() { @Test void testConditionalLeftJoinGatherMapsWithCount() { final int inv = Integer.MIN_VALUE; - BinaryExpression expr = new BinaryExpression(BinaryOperator.GREATER, + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); @@ -1585,7 +1585,7 @@ void testConditionalLeftJoinGatherMapsWithCount() { @Test void testConditionalLeftJoinGatherMapsNullsWithCount() { final int inv = Integer.MIN_VALUE; - BinaryExpression expr = new BinaryExpression(BinaryOperator.EQUAL, + BinaryOperation expr = new BinaryOperation(BinaryOperator.EQUAL, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder() @@ -1656,7 +1656,7 @@ void testInnerJoinGatherMapsNulls() { @Test void testConditionalInnerJoinGatherMaps() { - BinaryExpression expr = new BinaryExpression(BinaryOperator.GREATER, + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 
1, 7, 4, 6, 5, 8).build(); @@ -1681,7 +1681,7 @@ void testConditionalInnerJoinGatherMaps() { @Test void testConditionalInnerJoinGatherMapsNulls() { - BinaryExpression expr = new BinaryExpression(BinaryOperator.EQUAL, + BinaryOperation expr = new BinaryOperation(BinaryOperator.EQUAL, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder() @@ -1708,7 +1708,7 @@ void testConditionalInnerJoinGatherMapsNulls() { @Test void testConditionalInnerJoinGatherMapsWithCount() { - BinaryExpression expr = new BinaryExpression(BinaryOperator.GREATER, + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); @@ -1735,7 +1735,7 @@ void testConditionalInnerJoinGatherMapsWithCount() { @Test void testConditionalInnerJoinGatherMapsNullsWithCount() { - BinaryExpression expr = new BinaryExpression(BinaryOperator.EQUAL, + BinaryOperation expr = new BinaryOperation(BinaryOperator.EQUAL, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder() @@ -1809,7 +1809,7 @@ void testFullJoinGatherMapsNulls() { @Test void testConditionalFullJoinGatherMaps() { final int inv = Integer.MIN_VALUE; - BinaryExpression expr = new BinaryExpression(BinaryOperator.GREATER, + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); @@ -1835,7 +1835,7 @@ void testConditionalFullJoinGatherMaps() { @Test void testConditionalFullJoinGatherMapsNulls() { final int inv = Integer.MIN_VALUE; - BinaryExpression expr = new BinaryExpression(BinaryOperator.EQUAL, + BinaryOperation expr = new 
BinaryOperation(BinaryOperator.EQUAL, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder() @@ -1890,7 +1890,7 @@ void testLeftSemiJoinGatherMapNulls() { @Test void testConditionalLeftSemiJoinGatherMap() { - BinaryExpression expr = new BinaryExpression(BinaryOperator.GREATER, + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); @@ -1908,7 +1908,7 @@ void testConditionalLeftSemiJoinGatherMap() { @Test void testConditionalLeftSemiJoinGatherMapNulls() { - BinaryExpression expr = new BinaryExpression(BinaryOperator.EQUAL, + BinaryOperation expr = new BinaryOperation(BinaryOperator.EQUAL, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder() @@ -1928,7 +1928,7 @@ void testConditionalLeftSemiJoinGatherMapNulls() { @Test void testConditionalLeftSemiJoinGatherMapWithCount() { - BinaryExpression expr = new BinaryExpression(BinaryOperator.GREATER, + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); @@ -1950,7 +1950,7 @@ void testConditionalLeftSemiJoinGatherMapWithCount() { @Test void testConditionalLeftSemiJoinGatherMapNullsWithCount() { - BinaryExpression expr = new BinaryExpression(BinaryOperator.EQUAL, + BinaryOperation expr = new BinaryOperation(BinaryOperator.EQUAL, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder() @@ -2002,7 +2002,7 @@ void testAntiSemiJoinGatherMapNulls() { @Test void testConditionalLeftAntiJoinGatherMap() { - 
BinaryExpression expr = new BinaryExpression(BinaryOperator.GREATER, + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); @@ -2020,7 +2020,7 @@ void testConditionalLeftAntiJoinGatherMap() { @Test void testConditionalAntiSemiJoinGatherMapNulls() { - BinaryExpression expr = new BinaryExpression(BinaryOperator.EQUAL, + BinaryOperation expr = new BinaryOperation(BinaryOperator.EQUAL, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder() @@ -2040,7 +2040,7 @@ void testConditionalAntiSemiJoinGatherMapNulls() { @Test void testConditionalLeftAntiJoinGatherMapWithCount() { - BinaryExpression expr = new BinaryExpression(BinaryOperator.GREATER, + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); @@ -2062,7 +2062,7 @@ void testConditionalLeftAntiJoinGatherMapWithCount() { @Test void testConditionalAntiSemiJoinGatherMapNullsWithCount() { - BinaryExpression expr = new BinaryExpression(BinaryOperator.EQUAL, + BinaryOperation expr = new BinaryOperation(BinaryOperator.EQUAL, new ColumnReference(0, TableReference.LEFT), new ColumnReference(0, TableReference.RIGHT)); try (Table left = new Table.TestBuilder() diff --git a/java/src/test/java/ai/rapids/cudf/ast/CompiledExpressionTest.java b/java/src/test/java/ai/rapids/cudf/ast/CompiledExpressionTest.java index 177abe9d6e3..13af9aff682 100644 --- a/java/src/test/java/ai/rapids/cudf/ast/CompiledExpressionTest.java +++ b/java/src/test/java/ai/rapids/cudf/ast/CompiledExpressionTest.java @@ -42,16 +42,14 @@ public class CompiledExpressionTest extends CudfTestBase { public void 
testColumnReferenceTransform() { try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build()) { // use an implicit table reference - UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, - new ColumnReference(1)); + ColumnReference expr = new ColumnReference(1); try (CompiledExpression compiledExpr = expr.compile(); ColumnVector actual = compiledExpr.computeColumn(t)) { assertColumnsAreEqual(t.getColumn(1), actual); } // use an explicit table reference - expr = new UnaryExpression(UnaryOperator.IDENTITY, - new ColumnReference(1, TableReference.LEFT)); + expr = new ColumnReference(1, TableReference.LEFT); try (CompiledExpression compiledExpr = expr.compile(); ColumnVector actual = compiledExpr.computeColumn(t)) { assertColumnsAreEqual(t.getColumn(1), actual); @@ -62,8 +60,7 @@ public void testColumnReferenceTransform() { @Test public void testInvalidColumnReferenceTransform() { // Verify that computeColumn throws when passed an expression operating on TableReference.RIGHT. 
- UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, - new ColumnReference(1, TableReference.RIGHT)); + ColumnReference expr = new ColumnReference(1, TableReference.RIGHT); try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); CompiledExpression compiledExpr = expr.compile()) { Assertions.assertThrows(CudfException.class, () -> compiledExpr.computeColumn(t).close()); @@ -73,9 +70,8 @@ public void testInvalidColumnReferenceTransform() { @Test public void testBooleanLiteralTransform() { try (Table t = new Table.TestBuilder().column(true, false, null).build()) { - Literal trueLiteral = Literal.ofBoolean(true); - UnaryExpression trueExpr = new UnaryExpression(UnaryOperator.IDENTITY, trueLiteral); - try (CompiledExpression trueCompiledExpr = trueExpr.compile(); + Literal expr = Literal.ofBoolean(true); + try (CompiledExpression trueCompiledExpr = expr.compile(); ColumnVector trueExprActual = trueCompiledExpr.computeColumn(t); ColumnVector trueExprExpected = ColumnVector.fromBoxedBooleans(true, true, true)) { assertColumnsAreEqual(trueExprExpected, trueExprActual); @@ -83,7 +79,7 @@ public void testBooleanLiteralTransform() { // Uncomment the following after https://github.com/rapidsai/cudf/issues/8831 is fixed // Literal nullLiteral = Literal.ofBoolean(null); - // UnaryExpression nullExpr = new UnaryExpression(AstOperator.IDENTITY, nullLiteral); + // UnaryOperation nullExpr = new UnaryOperation(AstOperator.IDENTITY, nullLiteral); // try (CompiledExpression nullCompiledExpr = nullExpr.compile(); // ColumnVector nullExprActual = nullCompiledExpr.computeColumn(t); // ColumnVector nullExprExpected = ColumnVector.fromBoxedBooleans(null, null, null)) { @@ -97,8 +93,7 @@ public void testBooleanLiteralTransform() { // @NullSource @ValueSource(bytes = 0x12) public void testByteLiteralTransform(Byte value) { - Literal literal = Literal.ofByte(value); - UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, 
literal); + Literal expr = Literal.ofByte(value); try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); CompiledExpression compiledExpr = expr.compile(); ColumnVector actual = compiledExpr.computeColumn(t); @@ -112,8 +107,7 @@ public void testByteLiteralTransform(Byte value) { // @NullSource @ValueSource(shorts = 0x1234) public void testShortLiteralTransform(Short value) { - Literal literal = Literal.ofShort(value); - UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + Literal expr = Literal.ofShort(value); try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); CompiledExpression compiledExpr = expr.compile(); ColumnVector actual = compiledExpr.computeColumn(t); @@ -127,8 +121,7 @@ public void testShortLiteralTransform(Short value) { // @NullSource @ValueSource(ints = 0x12345678) public void testIntLiteralTransform(Integer value) { - Literal literal = Literal.ofInt(value); - UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + Literal expr = Literal.ofInt(value); try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); CompiledExpression compiledExpr = expr.compile(); ColumnVector actual = compiledExpr.computeColumn(t); @@ -142,8 +135,7 @@ public void testIntLiteralTransform(Integer value) { // @NullSource @ValueSource(longs = 0x1234567890abcdefL) public void testLongLiteralTransform(Long value) { - Literal literal = Literal.ofLong(value); - UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + Literal expr = Literal.ofLong(value); try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); CompiledExpression compiledExpr = expr.compile(); ColumnVector actual = compiledExpr.computeColumn(t); @@ -157,8 +149,7 @@ public void testLongLiteralTransform(Long value) { // @NullSource @ValueSource(floats = { 123456.789f, Float.NaN, 
Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY} ) public void testFloatLiteralTransform(Float value) { - Literal literal = Literal.ofFloat(value); - UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + Literal expr = Literal.ofFloat(value); try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); CompiledExpression compiledExpr = expr.compile(); ColumnVector actual = compiledExpr.computeColumn(t); @@ -172,8 +163,7 @@ public void testFloatLiteralTransform(Float value) { // @NullSource @ValueSource(doubles = { 123456.789f, Double.NaN, Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY} ) public void testDoubleLiteralTransform(Double value) { - Literal literal = Literal.ofDouble(value); - UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + Literal expr = Literal.ofDouble(value); try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); CompiledExpression compiledExpr = expr.compile(); ColumnVector actual = compiledExpr.computeColumn(t); @@ -187,8 +177,7 @@ public void testDoubleLiteralTransform(Double value) { // @NullSource @ValueSource(ints = 0x12345678) public void testTimestampDaysLiteralTransform(Integer value) { - Literal literal = Literal.ofTimestampDaysFromInt(value); - UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + Literal expr = Literal.ofTimestampDaysFromInt(value); try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); CompiledExpression compiledExpr = expr.compile(); ColumnVector actual = compiledExpr.computeColumn(t); @@ -203,8 +192,7 @@ public void testTimestampDaysLiteralTransform(Integer value) { // @NullSource @ValueSource(longs = 0x1234567890abcdefL) public void testTimestampSecondsLiteralTransform(Long value) { - Literal literal = Literal.ofTimestampFromLong(DType.TIMESTAMP_SECONDS, value); - UnaryExpression expr = new 
UnaryExpression(UnaryOperator.IDENTITY, literal); + Literal expr = Literal.ofTimestampFromLong(DType.TIMESTAMP_SECONDS, value); try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); CompiledExpression compiledExpr = expr.compile(); ColumnVector actual = compiledExpr.computeColumn(t); @@ -219,8 +207,7 @@ public void testTimestampSecondsLiteralTransform(Long value) { // @NullSource @ValueSource(longs = 0x1234567890abcdefL) public void testTimestampMilliSecondsLiteralTransform(Long value) { - Literal literal = Literal.ofTimestampFromLong(DType.TIMESTAMP_MILLISECONDS, value); - UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + Literal expr = Literal.ofTimestampFromLong(DType.TIMESTAMP_MILLISECONDS, value); try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); CompiledExpression compiledExpr = expr.compile(); ColumnVector actual = compiledExpr.computeColumn(t); @@ -235,8 +222,7 @@ public void testTimestampMilliSecondsLiteralTransform(Long value) { // @NullSource @ValueSource(longs = 0x1234567890abcdefL) public void testTimestampMicroSecondsLiteralTransform(Long value) { - Literal literal = Literal.ofTimestampFromLong(DType.TIMESTAMP_MICROSECONDS, value); - UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + Literal expr = Literal.ofTimestampFromLong(DType.TIMESTAMP_MICROSECONDS, value); try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); CompiledExpression compiledExpr = expr.compile(); ColumnVector actual = compiledExpr.computeColumn(t); @@ -251,8 +237,7 @@ public void testTimestampMicroSecondsLiteralTransform(Long value) { // @NullSource @ValueSource(longs = 0x1234567890abcdefL) public void testTimestampNanoSecondsLiteralTransform(Long value) { - Literal literal = Literal.ofTimestampFromLong(DType.TIMESTAMP_NANOSECONDS, value); - UnaryExpression expr = new 
UnaryExpression(UnaryOperator.IDENTITY, literal); + Literal expr = Literal.ofTimestampFromLong(DType.TIMESTAMP_NANOSECONDS, value); try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); CompiledExpression compiledExpr = expr.compile(); ColumnVector actual = compiledExpr.computeColumn(t); @@ -267,8 +252,7 @@ public void testTimestampNanoSecondsLiteralTransform(Long value) { // @NullSource @ValueSource(ints = 0x12345678) public void testDurationDaysLiteralTransform(Integer value) { - Literal literal = Literal.ofDurationDaysFromInt(value); - UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + Literal expr = Literal.ofDurationDaysFromInt(value); try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); CompiledExpression compiledExpr = expr.compile(); ColumnVector actual = compiledExpr.computeColumn(t); @@ -283,8 +267,7 @@ public void testDurationDaysLiteralTransform(Integer value) { // @NullSource @ValueSource(longs = 0x1234567890abcdefL) public void testDurationSecondsLiteralTransform(Long value) { - Literal literal = Literal.ofDurationFromLong(DType.DURATION_SECONDS, value); - UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + Literal expr = Literal.ofDurationFromLong(DType.DURATION_SECONDS, value); try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); CompiledExpression compiledExpr = expr.compile(); ColumnVector actual = compiledExpr.computeColumn(t); @@ -299,8 +282,7 @@ public void testDurationSecondsLiteralTransform(Long value) { // @NullSource @ValueSource(longs = 0x1234567890abcdefL) public void testDurationMilliSecondsLiteralTransform(Long value) { - Literal literal = Literal.ofDurationFromLong(DType.DURATION_MILLISECONDS, value); - UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + Literal expr = Literal.ofDurationFromLong(DType.DURATION_MILLISECONDS, 
value); try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); CompiledExpression compiledExpr = expr.compile(); ColumnVector actual = compiledExpr.computeColumn(t); @@ -315,8 +297,7 @@ public void testDurationMilliSecondsLiteralTransform(Long value) { // @NullSource @ValueSource(longs = 0x1234567890abcdefL) public void testDurationMicroSecondsLiteralTransform(Long value) { - Literal literal = Literal.ofDurationFromLong(DType.DURATION_MICROSECONDS, value); - UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + Literal expr = Literal.ofDurationFromLong(DType.DURATION_MICROSECONDS, value); try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); CompiledExpression compiledExpr = expr.compile(); ColumnVector actual = compiledExpr.computeColumn(t); @@ -331,8 +312,7 @@ public void testDurationMicroSecondsLiteralTransform(Long value) { // @NullSource @ValueSource(longs = 0x1234567890abcdefL) public void testDurationNanoSecondsLiteralTransform(Long value) { - Literal literal = Literal.ofDurationFromLong(DType.DURATION_NANOSECONDS, value); - UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, literal); + Literal expr = Literal.ofDurationFromLong(DType.DURATION_NANOSECONDS, value); try (Table t = new Table.TestBuilder().column(5, 4, 3, 2, 1).column(6, 7, 8, null, 10).build(); CompiledExpression compiledExpr = expr.compile(); ColumnVector actual = compiledExpr.computeColumn(t); @@ -359,7 +339,7 @@ private static ArrayList mapArray(T[] in1, U[] in2, BiFunction createUnaryDoubleExpressionParams() { + private static Stream createUnaryDoubleOperationParams() { Double[] input = new Double[] { -5., 4.5, null, 2.7, 1.5 }; return Stream.of( Arguments.of(UnaryOperator.IDENTITY, input, Arrays.asList(input)), @@ -383,10 +363,10 @@ private static Stream createUnaryDoubleExpressionParams() { } @ParameterizedTest - @MethodSource("createUnaryDoubleExpressionParams") - 
void testUnaryDoubleExpressionTransform(UnaryOperator op, Double[] input, + @MethodSource("createUnaryDoubleOperationParams") + void testUnaryDoubleOperationTransform(UnaryOperator op, Double[] input, List expectedValues) { - UnaryExpression expr = new UnaryExpression(op, new ColumnReference(0)); + UnaryOperation expr = new UnaryOperation(op, new ColumnReference(0)); try (Table t = new Table.TestBuilder().column(input).build(); CompiledExpression compiledExpr = expr.compile(); ColumnVector actual = compiledExpr.computeColumn(t); @@ -397,17 +377,17 @@ void testUnaryDoubleExpressionTransform(UnaryOperator op, Double[] input, } @Test - void testUnaryShortExpressionTransform() { + void testUnaryShortOperationTransform() { Short[] input = new Short[] { -5, 4, null, 2, 1 }; try (Table t = new Table.TestBuilder().column(input).build()) { - UnaryExpression expr = new UnaryExpression(UnaryOperator.IDENTITY, new ColumnReference(0)); + ColumnReference expr = new ColumnReference(0); try (CompiledExpression compiledExpr = expr.compile(); ColumnVector actual = compiledExpr.computeColumn(t)) { assertColumnsAreEqual(t.getColumn(0), actual); } - expr = new UnaryExpression(UnaryOperator.BIT_INVERT, new ColumnReference(0)); - try (CompiledExpression compiledExpr = expr.compile(); + UnaryOperation expr2 = new UnaryOperation(UnaryOperator.BIT_INVERT, new ColumnReference(0)); + try (CompiledExpression compiledExpr = expr2.compile(); ColumnVector actual = compiledExpr.computeColumn(t); ColumnVector expected = ColumnVector.fromBoxedInts(4, -5, null, -3, -2)) { assertColumnsAreEqual(expected, actual); @@ -416,8 +396,8 @@ void testUnaryShortExpressionTransform() { } @Test - void testUnaryLogicalExpressionTransform() { - UnaryExpression expr = new UnaryExpression(UnaryOperator.NOT, new ColumnReference(0)); + void testUnaryLogicalOperationTransform() { + UnaryOperation expr = new UnaryOperation(UnaryOperator.NOT, new ColumnReference(0)); try (Table t = new Table.TestBuilder().column(-5L, 0L, 
null, 2L, 1L).build(); CompiledExpression compiledExpr = expr.compile(); ColumnVector actual = compiledExpr.computeColumn(t); @@ -426,7 +406,7 @@ void testUnaryLogicalExpressionTransform() { } } - private static Stream createBinaryFloatExpressionParams() { + private static Stream createBinaryFloatOperationParams() { Float[] in1 = new Float[] { -5f, 4.5f, null, 2.7f }; Float[] in2 = new Float[] { 123f, -456f, null, 0f }; return Stream.of( @@ -442,10 +422,10 @@ private static Stream createBinaryFloatExpressionParams() { } @ParameterizedTest - @MethodSource("createBinaryFloatExpressionParams") - void testBinaryFloatExpressionTransform(BinaryOperator op, Float[] in1, Float[] in2, + @MethodSource("createBinaryFloatOperationParams") + void testBinaryFloatOperationTransform(BinaryOperator op, Float[] in1, Float[] in2, List expectedValues) { - BinaryExpression expr = new BinaryExpression(op, + BinaryOperation expr = new BinaryOperation(op, new ColumnReference(0), new ColumnReference(1)); try (Table t = new Table.TestBuilder().column(in1).column(in2).build(); @@ -457,7 +437,7 @@ void testBinaryFloatExpressionTransform(BinaryOperator op, Float[] in1, Float[] } } - private static Stream createBinaryDoublePromotedExpressionParams() { + private static Stream createBinaryDoublePromotedOperationParams() { Float[] in1 = new Float[] { -5f, 4.5f, null, 2.7f }; Float[] in2 = new Float[] { 123f, -456f, null, 0f }; return Stream.of( @@ -468,10 +448,10 @@ private static Stream createBinaryDoublePromotedExpressionParams() { } @ParameterizedTest - @MethodSource("createBinaryDoublePromotedExpressionParams") - void testBinaryDoublePromotedExpressionTransform(BinaryOperator op, Float[] in1, Float[] in2, + @MethodSource("createBinaryDoublePromotedOperationParams") + void testBinaryDoublePromotedOperationTransform(BinaryOperator op, Float[] in1, Float[] in2, List expectedValues) { - BinaryExpression expr = new BinaryExpression(op, + BinaryOperation expr = new BinaryOperation(op, new 
ColumnReference(0), new ColumnReference(1)); try (Table t = new Table.TestBuilder().column(in1).column(in2).build(); @@ -483,7 +463,7 @@ void testBinaryDoublePromotedExpressionTransform(BinaryOperator op, Float[] in1, } } - private static Stream createBinaryComparisonExpressionParams() { + private static Stream createBinaryComparisonOperationParams() { Integer[] in1 = new Integer[] { -5, 4, null, 2, -3 }; Integer[] in2 = new Integer[] { 123, -456, null, 0, -3 }; return Stream.of( @@ -497,10 +477,10 @@ private static Stream createBinaryComparisonExpressionParams() { } @ParameterizedTest - @MethodSource("createBinaryComparisonExpressionParams") - void testBinaryComparisonExpressionTransform(BinaryOperator op, Integer[] in1, Integer[] in2, + @MethodSource("createBinaryComparisonOperationParams") + void testBinaryComparisonOperationTransform(BinaryOperator op, Integer[] in1, Integer[] in2, List expectedValues) { - BinaryExpression expr = new BinaryExpression(op, + BinaryOperation expr = new BinaryOperation(op, new ColumnReference(0), new ColumnReference(1)); try (Table t = new Table.TestBuilder().column(in1).column(in2).build(); @@ -512,7 +492,7 @@ void testBinaryComparisonExpressionTransform(BinaryOperator op, Integer[] in1, I } } - private static Stream createBinaryBitwiseExpressionParams() { + private static Stream createBinaryBitwiseOperationParams() { Integer[] in1 = new Integer[] { -5, 4, null, 2, -3 }; Integer[] in2 = new Integer[] { 123, -456, null, 0, -3 }; return Stream.of( @@ -522,10 +502,10 @@ private static Stream createBinaryBitwiseExpressionParams() { } @ParameterizedTest - @MethodSource("createBinaryBitwiseExpressionParams") - void testBinaryBitwiseExpressionTransform(BinaryOperator op, Integer[] in1, Integer[] in2, + @MethodSource("createBinaryBitwiseOperationParams") + void testBinaryBitwiseOperationTransform(BinaryOperator op, Integer[] in1, Integer[] in2, List expectedValues) { - BinaryExpression expr = new BinaryExpression(op, + BinaryOperation 
expr = new BinaryOperation(op, new ColumnReference(0), new ColumnReference(1)); try (Table t = new Table.TestBuilder().column(in1).column(in2).build(); @@ -537,7 +517,7 @@ void testBinaryBitwiseExpressionTransform(BinaryOperator op, Integer[] in1, Inte } } - private static Stream createBinaryBooleanExpressionParams() { + private static Stream createBinaryBooleanOperationParams() { Boolean[] in1 = new Boolean[] { false, true, null, true, false }; Boolean[] in2 = new Boolean[] { true, null, null, true, false }; return Stream.of( @@ -546,10 +526,10 @@ private static Stream createBinaryBooleanExpressionParams() { } @ParameterizedTest - @MethodSource("createBinaryBooleanExpressionParams") - void testBinaryBooleanExpressionTransform(BinaryOperator op, Boolean[] in1, Boolean[] in2, + @MethodSource("createBinaryBooleanOperationParams") + void testBinaryBooleanOperationTransform(BinaryOperator op, Boolean[] in1, Boolean[] in2, List expectedValues) { - BinaryExpression expr = new BinaryExpression(op, + BinaryOperation expr = new BinaryOperation(op, new ColumnReference(0), new ColumnReference(1)); try (Table t = new Table.TestBuilder().column(in1).column(in2).build(); @@ -562,9 +542,9 @@ void testBinaryBooleanExpressionTransform(BinaryOperator op, Boolean[] in1, Bool } @Test - void testMismatchedBinaryExpressionTypes() { + void testMismatchedBinaryOperationTypes() { // verify expression fails to transform if operands are not the same type - BinaryExpression expr = new BinaryExpression(BinaryOperator.ADD, + BinaryOperation expr = new BinaryOperation(BinaryOperator.ADD, new ColumnReference(0), new ColumnReference(1)); try (Table t = new Table.TestBuilder().column(1, 2, 3).column(1L, 2L, 3L).build();