Skip to content

Commit

Permalink
Enable casting to int64, uint64, and double in AST code. (#9379)
Browse files Browse the repository at this point in the history
This PR resolves #8979, adding support for a few casting operators in AST code. These operators can be used to perform operations between columns with mismatched data types without materializing intermediates as new columns.

Authors:
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - Conor Hoekstra (https://github.com/codereport)
  - Karthikeyan (https://github.com/karthikeyann)
  - Jason Lowe (https://github.com/jlowe)

URL: #9379
  • Loading branch information
vyasr authored Oct 25, 2021
1 parent 30c31c9 commit ca40e18
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 46 deletions.
29 changes: 29 additions & 0 deletions cpp/include/cudf/ast/detail/operators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,15 @@ CUDA_HOST_DEVICE_CALLABLE constexpr void ast_operator_dispatcher(ast_operator op
case ast_operator::NOT:
f.template operator()<ast_operator::NOT>(std::forward<Ts>(args)...);
break;
case ast_operator::CAST_TO_INT64:
f.template operator()<ast_operator::CAST_TO_INT64>(std::forward<Ts>(args)...);
break;
case ast_operator::CAST_TO_UINT64:
f.template operator()<ast_operator::CAST_TO_UINT64>(std::forward<Ts>(args)...);
break;
case ast_operator::CAST_TO_FLOAT64:
f.template operator()<ast_operator::CAST_TO_FLOAT64>(std::forward<Ts>(args)...);
break;
default:
#ifndef __CUDA_ARCH__
CUDF_FAIL("Invalid operator.");
Expand Down Expand Up @@ -780,6 +789,26 @@ struct operator_functor<ast_operator::NOT, false> {
}
};

/**
 * @brief Unary functor that converts its argument to `To` using `static_cast`.
 *
 * The trailing return type keeps the call operator SFINAE-friendly: it only
 * participates in overload resolution when `static_cast<To>` is well-formed
 * for the argument type.
 *
 * @tparam To Destination type of the conversion.
 */
template <typename To>
struct cast {
  // Number of operands consumed by this functor.
  static constexpr auto arity{1};

  /**
   * @brief Convert `value` to `To`.
   *
   * @tparam Source Type of the incoming value.
   * @param value The value to convert.
   * @return `value` converted to `To` via `static_cast`.
   */
  template <typename Source>
  CUDA_DEVICE_CALLABLE auto operator()(Source value) -> decltype(static_cast<To>(value))
  {
    return static_cast<To>(value);
  }
};

// Cast operators: each specialization below inherits its `arity` and call
// operator from `cast<T>`, so the bodies are intentionally empty.
// NOTE(review): the `false` template argument presumably selects the
// non-nullable variant of the functor -- confirm against the primary template.

// CAST_TO_INT64: converts the operand to int64_t.
template <>
struct operator_functor<ast_operator::CAST_TO_INT64, false> : cast<int64_t> {
};
// CAST_TO_UINT64: converts the operand to uint64_t.
template <>
struct operator_functor<ast_operator::CAST_TO_UINT64, false> : cast<uint64_t> {
};
// CAST_TO_FLOAT64: converts the operand to double.
template <>
struct operator_functor<ast_operator::CAST_TO_FLOAT64, false> : cast<double> {
};

/*
* The default specialization of nullable operators is to fall back to the non-nullable
* implementation
Expand Down
49 changes: 26 additions & 23 deletions cpp/include/cudf/ast/expressions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,29 +88,32 @@ enum class ast_operator {
///< NULL_LOGICAL_OR(null, false) is null, and NULL_LOGICAL_OR(valid, valid) ==
///< LOGICAL_OR(valid, valid)
// Unary operators
IDENTITY, ///< Identity function
SIN, ///< Trigonometric sine
COS, ///< Trigonometric cosine
TAN, ///< Trigonometric tangent
ARCSIN, ///< Trigonometric sine inverse
ARCCOS, ///< Trigonometric cosine inverse
ARCTAN, ///< Trigonometric tangent inverse
SINH, ///< Hyperbolic sine
COSH, ///< Hyperbolic cosine
TANH, ///< Hyperbolic tangent
ARCSINH, ///< Hyperbolic sine inverse
ARCCOSH, ///< Hyperbolic cosine inverse
ARCTANH, ///< Hyperbolic tangent inverse
EXP, ///< Exponential (base e, Euler number)
LOG, ///< Natural Logarithm (base e)
SQRT, ///< Square-root (x^0.5)
CBRT, ///< Cube-root (x^(1.0/3))
CEIL, ///< Smallest integer value not less than arg
FLOOR, ///< largest integer value not greater than arg
ABS, ///< Absolute value
RINT, ///< Rounds the floating-point argument arg to an integer value
BIT_INVERT, ///< Bitwise Not (~)
NOT ///< Logical Not (!)
IDENTITY, ///< Identity function
SIN, ///< Trigonometric sine
COS, ///< Trigonometric cosine
TAN, ///< Trigonometric tangent
ARCSIN, ///< Trigonometric sine inverse
ARCCOS, ///< Trigonometric cosine inverse
ARCTAN, ///< Trigonometric tangent inverse
SINH, ///< Hyperbolic sine
COSH, ///< Hyperbolic cosine
TANH, ///< Hyperbolic tangent
ARCSINH, ///< Hyperbolic sine inverse
ARCCOSH, ///< Hyperbolic cosine inverse
ARCTANH, ///< Hyperbolic tangent inverse
EXP, ///< Exponential (base e, Euler number)
LOG, ///< Natural Logarithm (base e)
SQRT, ///< Square-root (x^0.5)
CBRT, ///< Cube-root (x^(1.0/3))
CEIL, ///< Smallest integer value not less than arg
FLOOR, ///< largest integer value not greater than arg
ABS, ///< Absolute value
RINT, ///< Rounds the floating-point argument arg to an integer value
BIT_INVERT, ///< Bitwise Not (~)
NOT, ///< Logical Not (!)
CAST_TO_INT64, ///< Cast value to int64_t
CAST_TO_UINT64, ///< Cast value to uint64_t
CAST_TO_FLOAT64 ///< Cast value to double
};

/**
Expand Down
17 changes: 17 additions & 0 deletions cpp/tests/ast/transform_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,23 @@ TEST_F(TransformTest, BasicAddition)
cudf::test::expect_columns_equal(expected, result->view(), verbosity);
}

// Adds an int64 column to an int8 column by wrapping the int8 column reference
// in a CAST_TO_INT64 operation, then compares against the expected int64 sums.
TEST_F(TransformTest, BasicAdditionCast)
{
  auto wide_col   = column_wrapper<int64_t>{3, 20, 1, 50};
  auto narrow_col = column_wrapper<int8_t>{10, 7, 20, 0};
  auto input      = cudf::table_view{{wide_col, narrow_col}};

  auto wide_ref   = cudf::ast::column_reference(0);
  auto narrow_ref = cudf::ast::column_reference(1);

  // Cast the int8 operand to int64 before combining it with the int64 operand.
  auto widened  = cudf::ast::operation(cudf::ast::ast_operator::CAST_TO_INT64, narrow_ref);
  auto sum_expr = cudf::ast::operation(cudf::ast::ast_operator::ADD, wide_ref, widened);

  auto result   = cudf::compute_column(input, sum_expr);
  auto expected = column_wrapper<int64_t>{13, 27, 21, 50};

  cudf::test::expect_columns_equal(expected, result->view(), verbosity);
}

TEST_F(TransformTest, BasicEquality)
{
auto c_0 = column_wrapper<int32_t>{3, 20, 1, 50};
Expand Down
49 changes: 26 additions & 23 deletions java/src/main/java/ai/rapids/cudf/ast/UnaryOperator.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,32 @@
* NOTE: This must be kept in sync with `jni_to_unary_operator` in CompiledExpression.cpp!
*/
public enum UnaryOperator {
IDENTITY(0), // Identity function
SIN(1), // Trigonometric sine
COS(2), // Trigonometric cosine
TAN(3), // Trigonometric tangent
ARCSIN(4), // Trigonometric sine inverse
ARCCOS(5), // Trigonometric cosine inverse
ARCTAN(6), // Trigonometric tangent inverse
SINH(7), // Hyperbolic sine
COSH(8), // Hyperbolic cosine
TANH(9), // Hyperbolic tangent
ARCSINH(10), // Hyperbolic sine inverse
ARCCOSH(11), // Hyperbolic cosine inverse
ARCTANH(12), // Hyperbolic tangent inverse
EXP(13), // Exponential (base e, Euler number)
LOG(14), // Natural Logarithm (base e)
SQRT(15), // Square-root (x^0.5)
CBRT(16), // Cube-root (x^(1.0/3))
CEIL(17), // Smallest integer value not less than arg
FLOOR(18), // largest integer value not greater than arg
ABS(19), // Absolute value
RINT(20), // Rounds the floating-point argument arg to an integer value
BIT_INVERT(21), // Bitwise Not (~)
NOT(22); // Logical Not (!)
IDENTITY(0), // Identity function
SIN(1), // Trigonometric sine
COS(2), // Trigonometric cosine
TAN(3), // Trigonometric tangent
ARCSIN(4), // Trigonometric sine inverse
ARCCOS(5), // Trigonometric cosine inverse
ARCTAN(6), // Trigonometric tangent inverse
SINH(7), // Hyperbolic sine
COSH(8), // Hyperbolic cosine
TANH(9), // Hyperbolic tangent
ARCSINH(10), // Hyperbolic sine inverse
ARCCOSH(11), // Hyperbolic cosine inverse
ARCTANH(12), // Hyperbolic tangent inverse
EXP(13), // Exponential (base e, Euler number)
LOG(14), // Natural Logarithm (base e)
SQRT(15), // Square-root (x^0.5)
CBRT(16), // Cube-root (x^(1.0/3))
CEIL(17), // Smallest integer value not less than arg
FLOOR(18), // largest integer value not greater than arg
ABS(19), // Absolute value
RINT(20), // Rounds the floating-point argument arg to an integer value
BIT_INVERT(21), // Bitwise Not (~)
NOT(22), // Logical Not (!)
CAST_TO_INT64(23), // Cast value to int64_t
CAST_TO_UINT64(24), // Cast value to uint64_t
CAST_TO_FLOAT64(25); // Cast value to double

private final byte nativeId;

Expand Down
3 changes: 3 additions & 0 deletions java/src/main/native/src/CompiledExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ cudf::ast::ast_operator jni_to_unary_operator(jbyte jni_op_value) {
case 20: return cudf::ast::ast_operator::RINT;
case 21: return cudf::ast::ast_operator::BIT_INVERT;
case 22: return cudf::ast::ast_operator::NOT;
case 23: return cudf::ast::ast_operator::CAST_TO_INT64;
case 24: return cudf::ast::ast_operator::CAST_TO_UINT64;
case 25: return cudf::ast::ast_operator::CAST_TO_FLOAT64;
default: throw std::invalid_argument("unexpected JNI AST unary operator value");
}
}
Expand Down

0 comments on commit ca40e18

Please sign in to comment.