Skip to content

Commit

Permalink
Change argument parameters to actually use CRTP to enforce the input …
Browse files Browse the repository at this point in the history
…type.
  • Loading branch information
vyasr committed Sep 9, 2021
1 parent 8bffd58 commit 758a2f0
Showing 1 changed file with 69 additions and 51 deletions.
120 changes: 69 additions & 51 deletions cpp/include/cudf/ast/detail/expression_evaluator.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ struct expression_result {
/**
* Helper function to get the subclass type to dispatch methods to.
*/
Subclass& subclass() { return static_cast<Subclass&>(*this); }
Subclass const& subclass() const { return static_cast<Subclass const&>(*this); }
__device__ Subclass& subclass() { return static_cast<Subclass&>(*this); }
__device__ Subclass const& subclass() const { return static_cast<Subclass const&>(*this); }

// TODO: The index is ignored by the value subclass, but is included in this
// signature because it is required by the implementation in the template
Expand All @@ -73,10 +73,10 @@ struct expression_result {
__device__ void set_value(cudf::size_type index,
possibly_null_value_t<Element, has_nulls> const& result)
{
subclass()->set_value(index, result);
subclass().template set_value<Element>(index, result);
}

__device__ bool is_valid() const { subclass()->is_valid(); }
__device__ bool is_valid() const { subclass().is_valid(); }

__device__ T value() const { subclass()->value(); }
};
Expand Down Expand Up @@ -349,8 +349,8 @@ struct expression_evaluator {
* @param output_row_index The row in the output to insert the result.
* @param op The operator to act with.
*/
template <typename Input, typename OutputType>
__device__ void operator()(OutputType& output_object,
template <typename Input, typename ResultSubclass, typename T, bool result_has_nulls>
__device__ void operator()(expression_result<ResultSubclass, T, result_has_nulls>& output_object,
cudf::size_type const input_row_index,
detail::device_data_reference const& input,
detail::device_data_reference const& output,
Expand Down Expand Up @@ -385,8 +385,8 @@ struct expression_evaluator {
* @param output_row_index The row in the output to insert the result.
* @param op The operator to act with.
*/
template <typename LHS, typename RHS, typename OutputType>
__device__ void operator()(OutputType& output_object,
template <typename LHS, typename RHS, typename ResultSubclass, typename T, bool result_has_nulls>
__device__ void operator()(expression_result<ResultSubclass, T, result_has_nulls>& output_object,
cudf::size_type const left_row_index,
cudf::size_type const right_row_index,
detail::device_data_reference const& lhs,
Expand Down Expand Up @@ -420,8 +420,8 @@ struct expression_evaluator {
* @param output_object The container that data will be inserted into.
* @param row_index Row index of all input and output data column(s).
*/
template <typename OutputType>
__device__ void evaluate(OutputType& output_object,
template <typename ResultSubclass, typename T, bool result_has_nulls>
__device__ void evaluate(expression_result<ResultSubclass, T, result_has_nulls>& output_object,
cudf::size_type const row_index,
IntermediateDataType<has_nulls>* thread_intermediate_storage)
{
Expand All @@ -440,8 +440,8 @@ struct expression_evaluator {
* @param right_row_index The row to pull the data from the right table.
* @param output_row_index The row in the output to insert the result.
*/
template <typename OutputType>
__device__ void evaluate(OutputType& output_object,
template <typename ResultSubclass, typename T, bool result_has_nulls>
__device__ void evaluate(expression_result<ResultSubclass, T, result_has_nulls>& output_object,
cudf::size_type const left_row_index,
cudf::size_type const right_row_index,
cudf::size_type const output_row_index,
Expand Down Expand Up @@ -538,13 +538,16 @@ struct expression_evaluator {
* @param result Value to assign to output.
*/
template <typename Element,
typename OutputType,
typename ResultSubclass,
typename T,
bool result_has_nulls,
CUDF_ENABLE_IF(is_rep_layout_compatible<Element>())>
__device__ void resolve_output(OutputType& output_object,
detail::device_data_reference const& device_data_reference,
cudf::size_type const row_index,
IntermediateDataType<has_nulls>* thread_intermediate_storage,
possibly_null_value_t<Element, has_nulls> const& result) const
__device__ void resolve_output(
expression_result<ResultSubclass, T, result_has_nulls>& output_object,
detail::device_data_reference const& device_data_reference,
cudf::size_type const row_index,
IntermediateDataType<has_nulls>* thread_intermediate_storage,
possibly_null_value_t<Element, has_nulls> const& result) const
{
if (device_data_reference.reference_type == detail::device_data_reference_type::COLUMN) {
output_object.template set_value<Element>(row_index, result);
Expand All @@ -558,13 +561,16 @@ struct expression_evaluator {
}

template <typename Element,
typename OutputType,
CUDF_ENABLE_IF(not is_rep_layout_compatible<Element>())>
__device__ void resolve_output(OutputType& output_object,
detail::device_data_reference const& device_data_reference,
cudf::size_type const row_index,
IntermediateDataType<has_nulls>* thread_intermediate_storage,
possibly_null_value_t<Element, has_nulls> const& result) const
typename ResultSubclass,
typename T,
bool result_has_nulls,
CUDF_ENABLE_IF(!is_rep_layout_compatible<Element>())>
__device__ void resolve_output(
expression_result<ResultSubclass, T, result_has_nulls>& output_object,
detail::device_data_reference const& device_data_reference,
cudf::size_type const row_index,
IntermediateDataType<has_nulls>* thread_intermediate_storage,
possibly_null_value_t<Element, has_nulls> const& result) const
{
cudf_assert(false && "Invalid type in resolve_output.");
}
Expand Down Expand Up @@ -592,15 +598,18 @@ struct expression_evaluator {
* @param output Output data reference.
*/
template <ast_operator op,
typename OutputType,
typename ResultSubclass,
typename T,
bool result_has_nulls,
std::enable_if_t<
detail::is_valid_unary_op<detail::operator_functor<op, has_nulls>,
possibly_null_value_t<Input, has_nulls>>>* = nullptr>
__device__ void operator()(OutputType& output_object,
cudf::size_type const output_row_index,
possibly_null_value_t<Input, has_nulls> const& input,
detail::device_data_reference const& output,
IntermediateDataType<has_nulls>* thread_intermediate_storage) const
__device__ void operator()(
expression_result<ResultSubclass, T, result_has_nulls>& output_object,
cudf::size_type const output_row_index,
possibly_null_value_t<Input, has_nulls> const& input,
detail::device_data_reference const& output,
IntermediateDataType<has_nulls>* thread_intermediate_storage) const
{
// The output data type is the same whether or not nulls are present, so
// pull from the non-nullable operator.
Expand All @@ -613,15 +622,18 @@ struct expression_evaluator {
}

template <ast_operator op,
typename OutputType,
typename ResultSubclass,
typename T,
bool result_has_nulls,
std::enable_if_t<
!detail::is_valid_unary_op<detail::operator_functor<op, has_nulls>,
possibly_null_value_t<Input, has_nulls>>>* = nullptr>
__device__ void operator()(OutputType& output_object,
cudf::size_type const output_row_index,
possibly_null_value_t<Input, has_nulls> const& input,
detail::device_data_reference const& output,
IntermediateDataType<has_nulls>* thread_intermediate_storage) const
__device__ void operator()(
expression_result<ResultSubclass, T, result_has_nulls>& output_object,
cudf::size_type const output_row_index,
possibly_null_value_t<Input, has_nulls> const& input,
detail::device_data_reference const& output,
IntermediateDataType<has_nulls>* thread_intermediate_storage) const
{
cudf_assert(false && "Invalid unary dispatch operator for the provided input.");
}
Expand Down Expand Up @@ -650,17 +662,20 @@ struct expression_evaluator {
* @param output Output data reference.
*/
template <ast_operator op,
typename OutputType,
typename ResultSubclass,
typename T,
bool result_has_nulls,
std::enable_if_t<detail::is_valid_binary_op<detail::operator_functor<op, has_nulls>,
possibly_null_value_t<LHS, has_nulls>,
possibly_null_value_t<RHS, has_nulls>>>* =
nullptr>
__device__ void operator()(OutputType& output_object,
cudf::size_type const output_row_index,
possibly_null_value_t<LHS, has_nulls> const& lhs,
possibly_null_value_t<RHS, has_nulls> const& rhs,
detail::device_data_reference const& output,
IntermediateDataType<has_nulls>* thread_intermediate_storage) const
__device__ void operator()(
expression_result<ResultSubclass, T, result_has_nulls>& output_object,
cudf::size_type const output_row_index,
possibly_null_value_t<LHS, has_nulls> const& lhs,
possibly_null_value_t<RHS, has_nulls> const& rhs,
detail::device_data_reference const& output,
IntermediateDataType<has_nulls>* thread_intermediate_storage) const
{
// The output data type is the same whether or not nulls are present, so
// pull from the non-nullable operator.
Expand All @@ -673,17 +688,20 @@ struct expression_evaluator {
}

template <ast_operator op,
typename OutputType,
typename ResultSubclass,
typename T,
bool result_has_nulls,
std::enable_if_t<
!detail::is_valid_binary_op<detail::operator_functor<op, has_nulls>,
possibly_null_value_t<LHS, has_nulls>,
possibly_null_value_t<RHS, has_nulls>>>* = nullptr>
__device__ void operator()(OutputType& output_object,
cudf::size_type const output_row_index,
possibly_null_value_t<LHS, has_nulls> const& lhs,
possibly_null_value_t<RHS, has_nulls> const& rhs,
detail::device_data_reference const& output,
IntermediateDataType<has_nulls>* thread_intermediate_storage) const
__device__ void operator()(
expression_result<ResultSubclass, T, result_has_nulls>& output_object,
cudf::size_type const output_row_index,
possibly_null_value_t<LHS, has_nulls> const& lhs,
possibly_null_value_t<RHS, has_nulls> const& rhs,
detail::device_data_reference const& output,
IntermediateDataType<has_nulls>* thread_intermediate_storage) const
{
cudf_assert(false && "Invalid binary dispatch operator for the provided input.");
}
Expand Down

0 comments on commit 758a2f0

Please sign in to comment.