From 758a2f0f5a5e5f3bec59f62212c1473a49c626d6 Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Fri, 27 Aug 2021 14:57:13 -0700 Subject: [PATCH] Change argument parameters to actually use CRTP to enforce the input type. --- .../cudf/ast/detail/expression_evaluator.cuh | 120 ++++++++++-------- 1 file changed, 69 insertions(+), 51 deletions(-) diff --git a/cpp/include/cudf/ast/detail/expression_evaluator.cuh b/cpp/include/cudf/ast/detail/expression_evaluator.cuh index 416d109eb89..8d38fb4064c 100644 --- a/cpp/include/cudf/ast/detail/expression_evaluator.cuh +++ b/cpp/include/cudf/ast/detail/expression_evaluator.cuh @@ -57,8 +57,8 @@ struct expression_result { /** * Helper function to get the subclass type to dispatch methods to. */ - Subclass& subclass() { return static_cast(*this); } - Subclass const& subclass() const { return static_cast(*this); } + __device__ Subclass& subclass() { return static_cast(*this); } + __device__ Subclass const& subclass() const { return static_cast(*this); } // TODO: The index is ignored by the value subclass, but is included in this // signature because it is required by the implementation in the template @@ -73,10 +73,10 @@ struct expression_result { __device__ void set_value(cudf::size_type index, possibly_null_value_t const& result) { - subclass()->set_value(index, result); + subclass().template set_value(index, result); } - __device__ bool is_valid() const { subclass()->is_valid(); } + __device__ bool is_valid() const { subclass().is_valid(); } __device__ T value() const { subclass()->value(); } }; @@ -349,8 +349,8 @@ struct expression_evaluator { * @param output_row_index The row in the output to insert the result. * @param op The operator to act with. */ - template - __device__ void operator()(OutputType& output_object, + template + __device__ void operator()(expression_result& output_object, cudf::size_type const input_row_index, detail::device_data_reference const& input, detail::device_data_reference const& output, @@ -385,8 +385,8 @@ struct expression_evaluator { * @param output_row_index The row in the output to insert the result. * @param op The operator to act with. */ - template - __device__ void operator()(OutputType& output_object, + template + __device__ void operator()(expression_result& output_object, cudf::size_type const left_row_index, cudf::size_type const right_row_index, detail::device_data_reference const& lhs, @@ -420,8 +420,8 @@ struct expression_evaluator { * @param output_object The container that data will be inserted into. * @param row_index Row index of all input and output data column(s). */ - template - __device__ void evaluate(OutputType& output_object, + template + __device__ void evaluate(expression_result& output_object, cudf::size_type const row_index, IntermediateDataType* thread_intermediate_storage) { @@ -440,8 +440,8 @@ struct expression_evaluator { * @param right_row_index The row to pull the data from the right table. * @param output_row_index The row in the output to insert the result. */ - template - __device__ void evaluate(OutputType& output_object, + template + __device__ void evaluate(expression_result& output_object, cudf::size_type const left_row_index, cudf::size_type const right_row_index, cudf::size_type const output_row_index, @@ -538,13 +538,16 @@ struct expression_evaluator { * @param result Value to assign to output. */ template ())> - __device__ void resolve_output(OutputType& output_object, - detail::device_data_reference const& device_data_reference, - cudf::size_type const row_index, - IntermediateDataType* thread_intermediate_storage, - possibly_null_value_t const& result) const + __device__ void resolve_output( + expression_result& output_object, + detail::device_data_reference const& device_data_reference, + cudf::size_type const row_index, + IntermediateDataType* thread_intermediate_storage, + possibly_null_value_t const& result) const { if (device_data_reference.reference_type == detail::device_data_reference_type::COLUMN) { output_object.template set_value(row_index, result); @@ -558,13 +561,16 @@ struct expression_evaluator { } template ())> - __device__ void resolve_output(OutputType& output_object, - detail::device_data_reference const& device_data_reference, - cudf::size_type const row_index, - IntermediateDataType* thread_intermediate_storage, - possibly_null_value_t const& result) const + typename ResultSubclass, + typename T, + bool result_has_nulls, + CUDF_ENABLE_IF(!is_rep_layout_compatible())> + __device__ void resolve_output( + expression_result& output_object, + detail::device_data_reference const& device_data_reference, + cudf::size_type const row_index, + IntermediateDataType* thread_intermediate_storage, + possibly_null_value_t const& result) const { cudf_assert(false && "Invalid type in resolve_output."); } @@ -592,15 +598,18 @@ struct expression_evaluator { * @param output Output data reference. */ template , possibly_null_value_t>>* = nullptr> - __device__ void operator()(OutputType& output_object, - cudf::size_type const output_row_index, - possibly_null_value_t const& input, - detail::device_data_reference const& output, - IntermediateDataType* thread_intermediate_storage) const + __device__ void operator()( + expression_result& output_object, + cudf::size_type const output_row_index, + possibly_null_value_t const& input, + detail::device_data_reference const& output, + IntermediateDataType* thread_intermediate_storage) const { // The output data type is the same whether or not nulls are present, so // pull from the non-nullable operator. @@ -613,15 +622,18 @@ struct expression_evaluator { } template , possibly_null_value_t>>* = nullptr> - __device__ void operator()(OutputType& output_object, - cudf::size_type const output_row_index, - possibly_null_value_t const& input, - detail::device_data_reference const& output, - IntermediateDataType* thread_intermediate_storage) const + __device__ void operator()( + expression_result& output_object, + cudf::size_type const output_row_index, + possibly_null_value_t const& input, + detail::device_data_reference const& output, + IntermediateDataType* thread_intermediate_storage) const { cudf_assert(false && "Invalid unary dispatch operator for the provided input."); } @@ -650,17 +662,20 @@ struct expression_evaluator { * @param output Output data reference. */ template , possibly_null_value_t, possibly_null_value_t>>* = nullptr> - __device__ void operator()(OutputType& output_object, - cudf::size_type const output_row_index, - possibly_null_value_t const& lhs, - possibly_null_value_t const& rhs, - detail::device_data_reference const& output, - IntermediateDataType* thread_intermediate_storage) const + __device__ void operator()( + expression_result& output_object, + cudf::size_type const output_row_index, + possibly_null_value_t const& lhs, + possibly_null_value_t const& rhs, + detail::device_data_reference const& output, + IntermediateDataType* thread_intermediate_storage) const { // The output data type is the same whether or not nulls are present, so // pull from the non-nullable operator. @@ -673,17 +688,20 @@ struct expression_evaluator { } template , possibly_null_value_t, possibly_null_value_t>>* = nullptr> - __device__ void operator()(OutputType& output_object, - cudf::size_type const output_row_index, - possibly_null_value_t const& lhs, - possibly_null_value_t const& rhs, - detail::device_data_reference const& output, - IntermediateDataType* thread_intermediate_storage) const + __device__ void operator()( + expression_result& output_object, + cudf::size_type const output_row_index, + possibly_null_value_t const& lhs, + possibly_null_value_t const& rhs, + detail::device_data_reference const& output, + IntermediateDataType* thread_intermediate_storage) const { cudf_assert(false && "Invalid binary dispatch operator for the provided input."); }