diff --git a/cpp/include/cudf/ast/detail/operators.hpp b/cpp/include/cudf/ast/detail/operators.hpp index 19df8d8e7b6..cffefcaf9cd 100644 --- a/cpp/include/cudf/ast/detail/operators.hpp +++ b/cpp/include/cudf/ast/detail/operators.hpp @@ -192,6 +192,15 @@ CUDA_HOST_DEVICE_CALLABLE constexpr void ast_operator_dispatcher(ast_operator op case ast_operator::NOT: f.template operator()(std::forward(args)...); break; + case ast_operator::CAST_TO_INT64: + f.template operator()(std::forward(args)...); + break; + case ast_operator::CAST_TO_UINT64: + f.template operator()(std::forward(args)...); + break; + case ast_operator::CAST_TO_FLOAT64: + f.template operator()(std::forward(args)...); + break; default: #ifndef __CUDA_ARCH__ CUDF_FAIL("Invalid operator."); @@ -780,6 +789,26 @@ struct operator_functor { } }; +template +struct cast { + static constexpr auto arity{1}; + template + CUDA_DEVICE_CALLABLE auto operator()(From f) -> decltype(static_cast(f)) + { + return static_cast(f); + } +}; + +template <> +struct operator_functor : cast { +}; +template <> +struct operator_functor : cast { +}; +template <> +struct operator_functor : cast { +}; + /* * The default specialization of nullable operators is to fall back to the non-nullable * implementation diff --git a/cpp/include/cudf/ast/expressions.hpp b/cpp/include/cudf/ast/expressions.hpp index 5454f9a2b95..7ae40a7d65f 100644 --- a/cpp/include/cudf/ast/expressions.hpp +++ b/cpp/include/cudf/ast/expressions.hpp @@ -88,29 +88,32 @@ enum class ast_operator { ///< NULL_LOGICAL_OR(null, false) is null, and NULL_LOGICAL_OR(valid, valid) == ///< LOGICAL_OR(valid, valid) // Unary operators - IDENTITY, ///< Identity function - SIN, ///< Trigonometric sine - COS, ///< Trigonometric cosine - TAN, ///< Trigonometric tangent - ARCSIN, ///< Trigonometric sine inverse - ARCCOS, ///< Trigonometric cosine inverse - ARCTAN, ///< Trigonometric tangent inverse - SINH, ///< Hyperbolic sine - COSH, ///< Hyperbolic cosine - TANH, ///< Hyperbolic tangent - ARCSINH, ///< Hyperbolic sine inverse - ARCCOSH, ///< Hyperbolic cosine inverse - ARCTANH, ///< Hyperbolic tangent inverse - EXP, ///< Exponential (base e, Euler number) - LOG, ///< Natural Logarithm (base e) - SQRT, ///< Square-root (x^0.5) - CBRT, ///< Cube-root (x^(1.0/3)) - CEIL, ///< Smallest integer value not less than arg - FLOOR, ///< largest integer value not greater than arg - ABS, ///< Absolute value - RINT, ///< Rounds the floating-point argument arg to an integer value - BIT_INVERT, ///< Bitwise Not (~) - NOT ///< Logical Not (!) + IDENTITY, ///< Identity function + SIN, ///< Trigonometric sine + COS, ///< Trigonometric cosine + TAN, ///< Trigonometric tangent + ARCSIN, ///< Trigonometric sine inverse + ARCCOS, ///< Trigonometric cosine inverse + ARCTAN, ///< Trigonometric tangent inverse + SINH, ///< Hyperbolic sine + COSH, ///< Hyperbolic cosine + TANH, ///< Hyperbolic tangent + ARCSINH, ///< Hyperbolic sine inverse + ARCCOSH, ///< Hyperbolic cosine inverse + ARCTANH, ///< Hyperbolic tangent inverse + EXP, ///< Exponential (base e, Euler number) + LOG, ///< Natural Logarithm (base e) + SQRT, ///< Square-root (x^0.5) + CBRT, ///< Cube-root (x^(1.0/3)) + CEIL, ///< Smallest integer value not less than arg + FLOOR, ///< largest integer value not greater than arg + ABS, ///< Absolute value + RINT, ///< Rounds the floating-point argument arg to an integer value + BIT_INVERT, ///< Bitwise Not (~) + NOT, ///< Logical Not (!) + CAST_TO_INT64, ///< Cast value to int64_t + CAST_TO_UINT64, ///< Cast value to uint64_t + CAST_TO_FLOAT64 ///< Cast value to double }; /** diff --git a/cpp/tests/ast/transform_tests.cpp b/cpp/tests/ast/transform_tests.cpp index 175918a0846..8cfd6d24fae 100644 --- a/cpp/tests/ast/transform_tests.cpp +++ b/cpp/tests/ast/transform_tests.cpp @@ -109,6 +109,23 @@ TEST_F(TransformTest, BasicAddition) cudf::test::expect_columns_equal(expected, result->view(), verbosity); } +TEST_F(TransformTest, BasicAdditionCast) +{ + auto c_0 = column_wrapper{3, 20, 1, 50}; + auto c_1 = column_wrapper{10, 7, 20, 0}; + auto table = cudf::table_view{{c_0, c_1}}; + + auto col_ref_0 = cudf::ast::column_reference(0); + auto col_ref_1 = cudf::ast::column_reference(1); + auto cast = cudf::ast::operation(cudf::ast::ast_operator::CAST_TO_INT64, col_ref_1); + auto expression = cudf::ast::operation(cudf::ast::ast_operator::ADD, col_ref_0, cast); + + auto expected = column_wrapper{13, 27, 21, 50}; + auto result = cudf::compute_column(table, expression); + + cudf::test::expect_columns_equal(expected, result->view(), verbosity); +} + TEST_F(TransformTest, BasicEquality) { auto c_0 = column_wrapper{3, 20, 1, 50}; diff --git a/java/src/main/java/ai/rapids/cudf/ast/UnaryOperator.java b/java/src/main/java/ai/rapids/cudf/ast/UnaryOperator.java index 9ef18dbd75d..6fb5a16d888 100644 --- a/java/src/main/java/ai/rapids/cudf/ast/UnaryOperator.java +++ b/java/src/main/java/ai/rapids/cudf/ast/UnaryOperator.java @@ -23,29 +23,32 @@ * NOTE: This must be kept in sync with `jni_to_unary_operator` in CompiledExpression.cpp! */ public enum UnaryOperator { - IDENTITY(0), // Identity function - SIN(1), // Trigonometric sine - COS(2), // Trigonometric cosine - TAN(3), // Trigonometric tangent - ARCSIN(4), // Trigonometric sine inverse - ARCCOS(5), // Trigonometric cosine inverse - ARCTAN(6), // Trigonometric tangent inverse - SINH(7), // Hyperbolic sine - COSH(8), // Hyperbolic cosine - TANH(9), // Hyperbolic tangent - ARCSINH(10), // Hyperbolic sine inverse - ARCCOSH(11), // Hyperbolic cosine inverse - ARCTANH(12), // Hyperbolic tangent inverse - EXP(13), // Exponential (base e, Euler number) - LOG(14), // Natural Logarithm (base e) - SQRT(15), // Square-root (x^0.5) - CBRT(16), // Cube-root (x^(1.0/3)) - CEIL(17), // Smallest integer value not less than arg - FLOOR(18), // largest integer value not greater than arg - ABS(19), // Absolute value - RINT(20), // Rounds the floating-point argument arg to an integer value - BIT_INVERT(21), // Bitwise Not (~) - NOT(22); // Logical Not (!) + IDENTITY(0), // Identity function + SIN(1), // Trigonometric sine + COS(2), // Trigonometric cosine + TAN(3), // Trigonometric tangent + ARCSIN(4), // Trigonometric sine inverse + ARCCOS(5), // Trigonometric cosine inverse + ARCTAN(6), // Trigonometric tangent inverse + SINH(7), // Hyperbolic sine + COSH(8), // Hyperbolic cosine + TANH(9), // Hyperbolic tangent + ARCSINH(10), // Hyperbolic sine inverse + ARCCOSH(11), // Hyperbolic cosine inverse + ARCTANH(12), // Hyperbolic tangent inverse + EXP(13), // Exponential (base e, Euler number) + LOG(14), // Natural Logarithm (base e) + SQRT(15), // Square-root (x^0.5) + CBRT(16), // Cube-root (x^(1.0/3)) + CEIL(17), // Smallest integer value not less than arg + FLOOR(18), // largest integer value not greater than arg + ABS(19), // Absolute value + RINT(20), // Rounds the floating-point argument arg to an integer value + BIT_INVERT(21), // Bitwise Not (~) + NOT(22), // Logical Not (!) + CAST_TO_INT64(23), // Cast value to int64_t + CAST_TO_UINT64(24), // Cast value to uint64_t + CAST_TO_FLOAT64(25); // Cast value to double private final byte nativeId; diff --git a/java/src/main/native/src/CompiledExpression.cpp b/java/src/main/native/src/CompiledExpression.cpp index 4b378905a43..a18c88e10dc 100644 --- a/java/src/main/native/src/CompiledExpression.cpp +++ b/java/src/main/native/src/CompiledExpression.cpp @@ -144,6 +144,9 @@ cudf::ast::ast_operator jni_to_unary_operator(jbyte jni_op_value) { case 20: return cudf::ast::ast_operator::RINT; case 21: return cudf::ast::ast_operator::BIT_INVERT; case 22: return cudf::ast::ast_operator::NOT; + case 23: return cudf::ast::ast_operator::CAST_TO_INT64; + case 24: return cudf::ast::ast_operator::CAST_TO_UINT64; + case 25: return cudf::ast::ast_operator::CAST_TO_FLOAT64; default: throw std::invalid_argument("unexpected JNI AST unary operator value"); } }