diff --git a/cpp/include/cudf/ast/detail/expression_evaluator.cuh b/cpp/include/cudf/ast/detail/expression_evaluator.cuh index b6c47afe19d..1c229d2e971 100644 --- a/cpp/include/cudf/ast/detail/expression_evaluator.cuh +++ b/cpp/include/cudf/ast/detail/expression_evaluator.cuh @@ -57,8 +57,11 @@ struct expression_result { /** * Helper function to get the subclass type to dispatch methods to. */ - Subclass& subclass() { return static_cast(*this); } - Subclass const& subclass() const { return static_cast(*this); } + CUDA_DEVICE_CALLABLE Subclass& subclass() { return static_cast(*this); } + CUDA_DEVICE_CALLABLE Subclass const& subclass() const + { + return static_cast(*this); + } // TODO: The index is ignored by the value subclass, but is included in this // signature because it is required by the implementation in the template @@ -70,14 +73,15 @@ struct expression_result { // used, whereas passing it as a parameter keeps it in registers for fast // access at the point where indexing occurs. template - __device__ void set_value(cudf::size_type index, possibly_null_value_t result) + CUDA_DEVICE_CALLABLE void set_value(cudf::size_type index, + possibly_null_value_t const& result) { - subclass()->set_value(); + subclass().template set_value(index, result); } - __device__ bool is_valid() const { subclass()->is_valid(); } + CUDA_DEVICE_CALLABLE bool is_valid() const { return subclass().is_valid(); } - __device__ T value() const { subclass()->value(); } + CUDA_DEVICE_CALLABLE T value() const { return subclass().value(); } }; /** @@ -93,10 +97,11 @@ struct expression_result { template struct value_expression_result : public expression_result, T, has_nulls> { - __device__ value_expression_result() {} + CUDA_DEVICE_CALLABLE value_expression_result() {} template - __device__ void set_value(cudf::size_type index, possibly_null_value_t result) + CUDA_DEVICE_CALLABLE void set_value(cudf::size_type index, + possibly_null_value_t const& result) { if constexpr (std::is_same_v) { _obj = result; @@ -108,7 +113,7 @@ struct value_expression_result /** * @brief Returns true if the underlying data is valid and false otherwise. */ - __device__ bool is_valid() const + CUDA_DEVICE_CALLABLE bool is_valid() const { if constexpr (has_nulls) { return _obj.has_value(); } return true; @@ -120,7 +125,7 @@ struct value_expression_result * If the underlying data is not valid, behavior is undefined. Callers should * use is_valid to check for validity before accessing the value. */ - __device__ T value() const + CUDA_DEVICE_CALLABLE T value() const { // Using two separate constexprs silences compiler warnings, whereas an // if/else does not. An unconditional return is not ignored by the compiler @@ -151,10 +156,13 @@ struct mutable_column_expression_result : public expression_result, mutable_column_device_view, has_nulls> { - __device__ mutable_column_expression_result(mutable_column_device_view& obj) : _obj(obj) {} + CUDA_DEVICE_CALLABLE mutable_column_expression_result(mutable_column_device_view& obj) : _obj(obj) + { + } template - __device__ void set_value(cudf::size_type index, possibly_null_value_t result) + CUDA_DEVICE_CALLABLE void set_value(cudf::size_type index, + possibly_null_value_t const& result) { if constexpr (has_nulls) { if (result.has_value()) { @@ -171,17 +179,19 @@ struct mutable_column_expression_result /** * @brief Not implemented for this specialization. */ - __device__ bool is_valid() const + CUDA_DEVICE_CALLABLE bool is_valid() const { // Not implemented since it would require modifying the API in the parent class to accept an // index. cudf_assert(false && "This method is not implemented."); + // Unreachable return used to silence compiler warnings. + return {}; } /** * @brief Not implemented for this specialization. */ - __device__ mutable_column_device_view value() const + CUDA_DEVICE_CALLABLE mutable_column_device_view value() const { // Not implemented since it would require modifying the API in the parent class to accept an // index. @@ -237,11 +247,10 @@ struct expression_evaluator { * storing intermediates. */ - __device__ expression_evaluator(table_device_view const& left, - table_device_view const& right, - expression_device_view const& plan, - IntermediateDataType* thread_intermediate_storage) - : left(left), right(right), plan(plan), thread_intermediate_storage(thread_intermediate_storage) + CUDA_DEVICE_CALLABLE expression_evaluator(table_device_view const& left, + table_device_view const& right, + expression_device_view const& plan) + : left(left), right(right), plan(plan) { } @@ -253,13 +262,9 @@ struct expression_evaluator { * @param thread_intermediate_storage Pointer to this thread's portion of shared memory for * storing intermediates. */ - __device__ expression_evaluator(table_device_view const& table, - expression_device_view const& plan, - IntermediateDataType* thread_intermediate_storage) - : left(table), - right(table), - plan(plan), - thread_intermediate_storage(thread_intermediate_storage) + CUDA_DEVICE_CALLABLE expression_evaluator(table_device_view const& table, + expression_device_view const& plan) + : expression_evaluator(table, table, plan) { } @@ -277,48 +282,47 @@ struct expression_evaluator { * @return Element The type- and null-resolved data. */ template ())> - __device__ possibly_null_value_t resolve_input( - detail::device_data_reference device_data_reference, + CUDA_DEVICE_CALLABLE possibly_null_value_t resolve_input( + detail::device_data_reference const& input_reference, + IntermediateDataType* thread_intermediate_storage, cudf::size_type left_row_index, thrust::optional right_row_index = {}) const { - auto const data_index = device_data_reference.data_index; - auto const ref_type = device_data_reference.reference_type; // TODO: Everywhere in the code assumes that the table reference is either // left or right. Should we error-check somewhere to prevent // table_reference::OUTPUT from being specified? using ReturnType = possibly_null_value_t; - if (ref_type == detail::device_data_reference_type::COLUMN) { + if (input_reference.reference_type == detail::device_data_reference_type::COLUMN) { // If we have nullable data, return an empty nullable type with no value if the data is null. - auto const& table = - (device_data_reference.table_source == table_reference::LEFT) ? left : right; + auto const& table = (input_reference.table_source == table_reference::LEFT) ? left : right; // Note that the code below assumes that a right index has been passed in - // any case where device_data_reference.table_source == table_reference::RIGHT. + // any case where input_reference.table_source == table_reference::RIGHT. // Otherwise, behavior is undefined. - auto const row_index = (device_data_reference.table_source == table_reference::LEFT) - ? left_row_index - : *right_row_index; + auto const row_index = + (input_reference.table_source == table_reference::LEFT) ? left_row_index : *right_row_index; if constexpr (has_nulls) { - return table.column(data_index).is_valid(row_index) - ? ReturnType(table.column(data_index).element(row_index)) + return table.column(input_reference.data_index).is_valid(row_index) + ? ReturnType(table.column(input_reference.data_index).element(row_index)) : ReturnType(); } else { - return ReturnType(table.column(data_index).element(row_index)); + return ReturnType(table.column(input_reference.data_index).element(row_index)); } - } else if (ref_type == detail::device_data_reference_type::LITERAL) { + } else if (input_reference.reference_type == detail::device_data_reference_type::LITERAL) { if constexpr (has_nulls) { - return plan.literals[data_index].is_valid() - ? ReturnType(plan.literals[data_index].value()) + return plan.literals[input_reference.data_index].is_valid() + ? ReturnType(plan.literals[input_reference.data_index].value()) : ReturnType(); } else { - return ReturnType(plan.literals[data_index].value()); + return ReturnType(plan.literals[input_reference.data_index].value()); } - } else { // Assumes ref_type == detail::device_data_reference_type::INTERMEDIATE + } else { // Assumes input_reference.reference_type == + // detail::device_data_reference_type::INTERMEDIATE // Using memcpy instead of reinterpret_cast for safe type aliasing // Using a temporary variable ensures that the compiler knows the result is aligned - IntermediateDataType intermediate = thread_intermediate_storage[data_index]; + IntermediateDataType intermediate = + thread_intermediate_storage[input_reference.data_index]; ReturnType tmp; memcpy(&tmp, &intermediate, sizeof(ReturnType)); return tmp; @@ -329,8 +333,9 @@ struct expression_evaluator { template ())> - __device__ possibly_null_value_t resolve_input( - detail::device_data_reference device_data_reference, + CUDA_DEVICE_CALLABLE possibly_null_value_t resolve_input( + detail::device_data_reference const& device_data_reference, + IntermediateDataType* thread_intermediate_storage, cudf::size_type left_row_index, thrust::optional right_row_index = {}) const { @@ -352,25 +357,29 @@ struct expression_evaluator { * @param output_row_index The row in the output to insert the result. * @param op The operator to act with. */ - template - __device__ void operator()(OutputType& output_object, - cudf::size_type const input_row_index, - detail::device_data_reference const input, - detail::device_data_reference const output, - cudf::size_type const output_row_index, - ast_operator const op) const + template + CUDA_DEVICE_CALLABLE void operator()( + expression_result& output_object, + cudf::size_type const input_row_index, + detail::device_data_reference const& input, + detail::device_data_reference const& output, + cudf::size_type const output_row_index, + ast_operator const op, + IntermediateDataType* thread_intermediate_storage) const { - auto const typed_input = resolve_input(input, input_row_index); + auto const typed_input = + resolve_input(input, thread_intermediate_storage, input_row_index); ast_operator_dispatcher(op, - unary_expression_output_handler(*this), + unary_expression_output_handler{}, output_object, output_row_index, typed_input, - output); + output, + thread_intermediate_storage); } /** - * @brief Callable to perform a unary operation. + * @brief Callable to perform a binary operation. * * @tparam LHS Type of the left input value. * @tparam RHS Type of the right input value. @@ -385,42 +394,30 @@ struct expression_evaluator { * @param output_row_index The row in the output to insert the result. * @param op The operator to act with. */ - template - __device__ void operator()(OutputType& output_object, - cudf::size_type const left_row_index, - cudf::size_type const right_row_index, - detail::device_data_reference const lhs, - detail::device_data_reference const rhs, - detail::device_data_reference const output, - cudf::size_type const output_row_index, - ast_operator const op) const + template + CUDA_DEVICE_CALLABLE void operator()( + expression_result& output_object, + cudf::size_type const left_row_index, + cudf::size_type const right_row_index, + detail::device_data_reference const& lhs, + detail::device_data_reference const& rhs, + detail::device_data_reference const& output, + cudf::size_type const output_row_index, + ast_operator const op, + IntermediateDataType* thread_intermediate_storage) const { - auto const typed_lhs = resolve_input(lhs, left_row_index, right_row_index); - auto const typed_rhs = resolve_input(rhs, left_row_index, right_row_index); + auto const typed_lhs = + resolve_input(lhs, thread_intermediate_storage, left_row_index, right_row_index); + auto const typed_rhs = + resolve_input(rhs, thread_intermediate_storage, left_row_index, right_row_index); ast_operator_dispatcher(op, - binary_expression_output_handler(*this), + binary_expression_output_handler{}, output_object, output_row_index, typed_lhs, typed_rhs, - output); - } - - template >* = nullptr> - __device__ void operator()(OutputType& output_object, - cudf::size_type left_row_index, - cudf::size_type right_row_index, - detail::device_data_reference const lhs, - detail::device_data_reference const rhs, - detail::device_data_reference const output, - cudf::size_type output_row_index, - ast_operator const op) const - { - cudf_assert(false && "Invalid binary dispatch operator for the provided input."); + output, + thread_intermediate_storage); } /** @@ -433,10 +430,13 @@ struct expression_evaluator { * @param output_object The container that data will be inserted into. * @param row_index Row index of all input and output data column(s). */ - template - __device__ void evaluate(OutputType& output_object, cudf::size_type const row_index) + template + CUDA_DEVICE_CALLABLE void evaluate( + expression_result& output_object, + cudf::size_type const row_index, + IntermediateDataType* thread_intermediate_storage) { - evaluate(output_object, row_index, row_index, row_index); + evaluate(output_object, row_index, row_index, row_index, thread_intermediate_storage); } /** @@ -451,11 +451,13 @@ struct expression_evaluator { * @param right_row_index The row to pull the data from the right table. * @param output_row_index The row in the output to insert the result. */ - template - __device__ void evaluate(OutputType& output_object, - cudf::size_type const left_row_index, - cudf::size_type const right_row_index, - cudf::size_type const output_row_index) + template + CUDA_DEVICE_CALLABLE void evaluate( + expression_result& output_object, + cudf::size_type const left_row_index, + cudf::size_type const right_row_index, + cudf::size_type const output_row_index, + IntermediateDataType* thread_intermediate_storage) { cudf::size_type operator_source_index{0}; for (cudf::size_type operator_index = 0; operator_index < plan.operators.size(); @@ -465,9 +467,9 @@ struct expression_evaluator { auto const arity = ast_operator_arity(op); if (arity == 1) { // Unary operator - auto const input = + auto const& input = plan.data_references[plan.operator_source_indices[operator_source_index++]]; - auto const output = + auto const& output = plan.data_references[plan.operator_source_indices[operator_source_index++]]; auto input_row_index = input.table_source == table_reference::LEFT ? left_row_index : right_row_index; @@ -478,14 +480,15 @@ struct expression_evaluator { input, output, output_row_index, - op); + op, + thread_intermediate_storage); } else if (arity == 2) { // Binary operator - auto const lhs = + auto const& lhs = plan.data_references[plan.operator_source_indices[operator_source_index++]]; - auto const rhs = + auto const& rhs = plan.data_references[plan.operator_source_indices[operator_source_index++]]; - auto const output = + auto const& output = plan.data_references[plan.operator_source_indices[operator_source_index++]]; type_dispatcher(lhs.data_type, detail::single_dispatch_binary_operator{}, @@ -497,7 +500,8 @@ struct expression_evaluator { rhs, output, output_row_index, - op); + op, + thread_intermediate_storage); } else { cudf_assert(false && "Invalid operator arity."); } @@ -515,10 +519,7 @@ struct expression_evaluator { */ struct expression_output_handler { public: - __device__ expression_output_handler(expression_evaluator const& evaluator) - : evaluator(evaluator) - { - } + CUDA_DEVICE_CALLABLE expression_output_handler() {} /** * @brief Resolves an output data reference and assigns result value. @@ -536,38 +537,43 @@ struct expression_evaluator { * @param result Value to assign to output. */ template ())> - __device__ void resolve_output(OutputType& output_object, - detail::device_data_reference const device_data_reference, - cudf::size_type const row_index, - possibly_null_value_t const result) const + CUDA_DEVICE_CALLABLE void resolve_output( + expression_result& output_object, + detail::device_data_reference const& device_data_reference, + cudf::size_type const row_index, + IntermediateDataType* thread_intermediate_storage, + possibly_null_value_t const& result) const { - auto const ref_type = device_data_reference.reference_type; - if (ref_type == detail::device_data_reference_type::COLUMN) { + if (device_data_reference.reference_type == detail::device_data_reference_type::COLUMN) { output_object.template set_value(row_index, result); - } else { // Assumes ref_type == detail::device_data_reference_type::INTERMEDIATE + } else { // Assumes device_data_reference.reference_type == + // detail::device_data_reference_type::INTERMEDIATE // Using memcpy instead of reinterpret_cast for safe type aliasing. // Using a temporary variable ensures that the compiler knows the result is aligned. IntermediateDataType tmp; memcpy(&tmp, &result, sizeof(possibly_null_value_t)); - evaluator.thread_intermediate_storage[device_data_reference.data_index] = tmp; + thread_intermediate_storage[device_data_reference.data_index] = tmp; } } template ())> - __device__ void resolve_output(OutputType& output_object, - detail::device_data_reference const device_data_reference, - cudf::size_type const row_index, - possibly_null_value_t const result) const + typename ResultSubclass, + typename T, + bool result_has_nulls, + CUDF_ENABLE_IF(!is_rep_layout_compatible())> + CUDA_DEVICE_CALLABLE void resolve_output( + expression_result& output_object, + detail::device_data_reference const& device_data_reference, + cudf::size_type const row_index, + IntermediateDataType* thread_intermediate_storage, + possibly_null_value_t const& result) const { cudf_assert(false && "Invalid type in resolve_output."); } - - protected: - expression_evaluator const& evaluator; }; /** @@ -578,10 +584,7 @@ struct expression_evaluator { */ template struct unary_expression_output_handler : public expression_output_handler { - __device__ unary_expression_output_handler(expression_evaluator const& evaluator) - : expression_output_handler(evaluator) - { - } + CUDA_DEVICE_CALLABLE unary_expression_output_handler() {} /** * @brief Callable to perform a unary operation. @@ -595,31 +598,42 @@ struct expression_evaluator { * @param output Output data reference. */ template , possibly_null_value_t>>* = nullptr> - __device__ void operator()(OutputType& output_object, - cudf::size_type const output_row_index, - possibly_null_value_t const input, - detail::device_data_reference const output) const + CUDA_DEVICE_CALLABLE void operator()( + expression_result& output_object, + cudf::size_type const output_row_index, + possibly_null_value_t const& input, + detail::device_data_reference const& output, + IntermediateDataType* thread_intermediate_storage) const { // The output data type is the same whether or not nulls are present, so // pull from the non-nullable operator. using Out = cuda::std::invoke_result_t, Input>; - this->template resolve_output( - output_object, output, output_row_index, detail::operator_functor{}(input)); + this->template resolve_output(output_object, + output, + output_row_index, + thread_intermediate_storage, + detail::operator_functor{}(input)); } template , possibly_null_value_t>>* = nullptr> - __device__ void operator()(OutputType& output_object, - cudf::size_type const output_row_index, - possibly_null_value_t const input, - detail::device_data_reference const output) const + CUDA_DEVICE_CALLABLE void operator()( + expression_result& output_object, + cudf::size_type const output_row_index, + possibly_null_value_t const& input, + detail::device_data_reference const& output, + IntermediateDataType* thread_intermediate_storage) const { cudf_assert(false && "Invalid unary dispatch operator for the provided input."); } @@ -633,10 +647,7 @@ struct expression_evaluator { */ template struct binary_expression_output_handler : public expression_output_handler { - __device__ binary_expression_output_handler(expression_evaluator const& evaluator) - : expression_output_handler(evaluator) - { - } + CUDA_DEVICE_CALLABLE binary_expression_output_handler() {} /** * @brief Callable to perform a binary operation. @@ -651,16 +662,20 @@ struct expression_evaluator { * @param output Output data reference. */ template , possibly_null_value_t, possibly_null_value_t>>* = nullptr> - __device__ void operator()(OutputType& output_object, - cudf::size_type const output_row_index, - possibly_null_value_t const lhs, - possibly_null_value_t const rhs, - detail::device_data_reference const output) const + CUDA_DEVICE_CALLABLE void operator()( + expression_result& output_object, + cudf::size_type const output_row_index, + possibly_null_value_t const& lhs, + possibly_null_value_t const& rhs, + detail::device_data_reference const& output, + IntermediateDataType* thread_intermediate_storage) const { // The output data type is the same whether or not nulls are present, so // pull from the non-nullable operator. @@ -668,20 +683,25 @@ struct expression_evaluator { this->template resolve_output(output_object, output, output_row_index, + thread_intermediate_storage, detail::operator_functor{}(lhs, rhs)); } template , possibly_null_value_t, possibly_null_value_t>>* = nullptr> - __device__ void operator()(OutputType& output_object, - cudf::size_type const output_row_index, - possibly_null_value_t const lhs, - possibly_null_value_t const rhs, - detail::device_data_reference output) const + CUDA_DEVICE_CALLABLE void operator()( + expression_result& output_object, + cudf::size_type const output_row_index, + possibly_null_value_t const& lhs, + possibly_null_value_t const& rhs, + detail::device_data_reference const& output, + IntermediateDataType* thread_intermediate_storage) const { cudf_assert(false && "Invalid binary dispatch operator for the provided input."); } @@ -691,9 +711,6 @@ struct expression_evaluator { table_device_view const& right; ///< The right table to operate on. expression_device_view const& plan; ///< The container of device data representing the expression to evaluate. - IntermediateDataType* - thread_intermediate_storage; ///< The shared memory store of intermediates produced during - ///< evaluation. }; } // namespace detail diff --git a/cpp/include/cudf/ast/detail/expression_parser.hpp b/cpp/include/cudf/ast/detail/expression_parser.hpp index 1f35b54ea61..dc800bde527 100644 --- a/cpp/include/cudf/ast/detail/expression_parser.hpp +++ b/cpp/include/cudf/ast/detail/expression_parser.hpp @@ -107,7 +107,6 @@ struct expression_device_view { device_span operators; device_span operator_source_indices; cudf::size_type num_intermediates; - int shmem_per_thread; }; /** @@ -230,6 +229,7 @@ class expression_parser { expression_device_view device_expression_data; ///< The collection of data required to evaluate ///< the expression on the device. + int shmem_per_thread; private: /** @@ -292,7 +292,7 @@ class expression_parser { reinterpret_cast(device_data_buffer_ptr + buffer_offsets[3]), _operator_source_indices.size()); device_expression_data.num_intermediates = _intermediate_counter.get_max_used(); - device_expression_data.shmem_per_thread = static_cast( + shmem_per_thread = static_cast( (_has_nulls ? sizeof(IntermediateDataType) : sizeof(IntermediateDataType)) * device_expression_data.num_intermediates); } diff --git a/cpp/src/join/conditional_join.cu b/cpp/src/join/conditional_join.cu index 1f49ee749ec..f4fe670d3cd 100644 --- a/cpp/src/join/conditional_join.cu +++ b/cpp/src/join/conditional_join.cu @@ -94,8 +94,7 @@ conditional_join(table_view const& left, // Allocate storage for the counter used to get the size of the join output detail::grid_1d config(left_table->num_rows(), DEFAULT_JOIN_BLOCK_SIZE); - auto const shmem_size_per_block = - parser.device_expression_data.shmem_per_thread * config.num_threads_per_block; + auto const shmem_size_per_block = parser.shmem_per_thread * config.num_threads_per_block; join_kind kernel_join_type = join_type == join_kind::FULL_JOIN ? join_kind::LEFT_JOIN : join_type; // If the join size was not provided as an input, compute it here. @@ -229,8 +228,7 @@ std::size_t compute_conditional_join_output_size(table_view const& left, rmm::device_scalar size(0, stream, mr); CHECK_CUDA(stream.value()); detail::grid_1d config(left_table->num_rows(), DEFAULT_JOIN_BLOCK_SIZE); - auto const shmem_size_per_block = - parser.device_expression_data.shmem_per_thread * config.num_threads_per_block; + auto const shmem_size_per_block = parser.shmem_per_thread * config.num_threads_per_block; // Determine number of output rows without actually building the output to simply // find what the size of the output will be. diff --git a/cpp/src/join/conditional_join_kernels.cuh b/cpp/src/join/conditional_join_kernels.cuh index 2ad7c6ad8b8..6dc204441b3 100644 --- a/cpp/src/join/conditional_join_kernels.cuh +++ b/cpp/src/join/conditional_join_kernels.cuh @@ -70,14 +70,15 @@ __global__ void compute_conditional_join_output_size( cudf::size_type const right_num_rows = right_table.num_rows(); auto evaluator = cudf::ast::detail::expression_evaluator( - left_table, right_table, device_expression_data, thread_intermediate_storage); + left_table, right_table, device_expression_data); for (cudf::size_type left_row_index = left_start_idx; left_row_index < left_num_rows; left_row_index += left_stride) { bool found_match = false; for (cudf::size_type right_row_index = 0; right_row_index < right_num_rows; right_row_index++) { auto output_dest = cudf::ast::detail::value_expression_result(); - evaluator.evaluate(output_dest, left_row_index, right_row_index, 0); + evaluator.evaluate( + output_dest, left_row_index, right_row_index, 0, thread_intermediate_storage); if (output_dest.is_valid() && output_dest.value()) { if ((join_type != join_kind::LEFT_ANTI_JOIN) && !(join_type == join_kind::LEFT_SEMI_JOIN && found_match)) { @@ -161,13 +162,14 @@ __global__ void conditional_join(table_device_view left_table, unsigned int const activemask = __ballot_sync(0xffffffff, left_row_index < left_num_rows); auto evaluator = cudf::ast::detail::expression_evaluator( - left_table, right_table, device_expression_data, thread_intermediate_storage); + left_table, right_table, device_expression_data); if (left_row_index < left_num_rows) { bool found_match = false; for (size_type right_row_index(0); right_row_index < right_num_rows; ++right_row_index) { auto output_dest = cudf::ast::detail::value_expression_result(); - evaluator.evaluate(output_dest, left_row_index, right_row_index, 0); + evaluator.evaluate( + output_dest, left_row_index, right_row_index, 0, thread_intermediate_storage); if (output_dest.is_valid() && output_dest.value()) { // If the rows are equal, then we have found a true match diff --git a/cpp/src/transform/compute_column.cu b/cpp/src/transform/compute_column.cu index 742c01b9c60..bf109dbe1e5 100644 --- a/cpp/src/transform/compute_column.cu +++ b/cpp/src/transform/compute_column.cu @@ -69,12 +69,12 @@ __launch_bounds__(max_block_size) __global__ &intermediate_storage[threadIdx.x * device_expression_data.num_intermediates]; auto const start_idx = static_cast(threadIdx.x + blockIdx.x * blockDim.x); auto const stride = static_cast(blockDim.x * gridDim.x); - auto evaluator = cudf::ast::detail::expression_evaluator( - table, device_expression_data, thread_intermediate_storage); + auto evaluator = + cudf::ast::detail::expression_evaluator(table, device_expression_data); for (cudf::size_type row_index = start_idx; row_index < table.num_rows(); row_index += stride) { auto output_dest = ast::detail::mutable_column_expression_result(output_column); - evaluator.evaluate(output_dest, row_index); + evaluator.evaluate(output_dest, row_index, thread_intermediate_storage); } } @@ -107,12 +107,11 @@ std::unique_ptr compute_column(table_view const& table, cudaDeviceGetAttribute(&shmem_limit_per_block, cudaDevAttrMaxSharedMemoryPerBlock, device_id)); auto constexpr MAX_BLOCK_SIZE = 128; auto const block_size = - device_expression_data.shmem_per_thread != 0 - ? std::min(MAX_BLOCK_SIZE, shmem_limit_per_block / device_expression_data.shmem_per_thread) + parser.shmem_per_thread != 0 + ? std::min(MAX_BLOCK_SIZE, shmem_limit_per_block / parser.shmem_per_thread) : MAX_BLOCK_SIZE; - auto const config = cudf::detail::grid_1d{table.num_rows(), block_size}; - auto const shmem_per_block = - device_expression_data.shmem_per_thread * config.num_threads_per_block; + auto const config = cudf::detail::grid_1d{table.num_rows(), block_size}; + auto const shmem_per_block = parser.shmem_per_thread * config.num_threads_per_block; // Execute the kernel auto table_device = table_device_view::create(table, stream);