Skip to content

Commit

Permalink
Move the binary_ops common dispatcher logic to be executed on the CPU (
Browse files Browse the repository at this point in the history
…#9816)

* move NullEquals to separate file

* To improve runtime performance move more binary_ops dispatch to host

* make sure to forceinline the operator_dispatcher

* Correct style issues found by ci

* Expand the binary-op compiled benchmark suite

* Ensure forceinline is on binary ops device dispatch functions

* Correct style issues found by ci

Co-authored-by: Karthikeyan Natarajan <[email protected]>
Co-authored-by: Karthikeyan <[email protected]>
  • Loading branch information
3 people authored Dec 3, 2021
1 parent 74ac6ed commit 69e6dbb
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 72 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ add_library(
src/binaryop/compiled/LogicalOr.cu
src/binaryop/compiled/Mod.cu
src/binaryop/compiled/Mul.cu
src/binaryop/compiled/NullEquals.cu
src/binaryop/compiled/NullMax.cu
src/binaryop/compiled/NullMin.cu
src/binaryop/compiled/PMod.cu
Expand Down
66 changes: 36 additions & 30 deletions cpp/benchmarks/binaryop/compiled_binaryop_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ void BM_compiled_binaryop(benchmark::State& state, cudf::binary_operator binop)
}

// TODO tparam boolean for null.
#define BINARYOP_BENCHMARK_DEFINE(TypeLhs, TypeRhs, binop, TypeOut) \
#define BINARYOP_BENCHMARK_DEFINE(name, TypeLhs, TypeRhs, binop, TypeOut) \
BENCHMARK_TEMPLATE_DEFINE_F( \
COMPILED_BINARYOP, binop, TypeLhs, TypeRhs, TypeOut, cudf::binary_operator::binop) \
COMPILED_BINARYOP, name, TypeLhs, TypeRhs, TypeOut, cudf::binary_operator::binop) \
(::benchmark::State & st) \
{ \
BM_compiled_binaryop<TypeLhs, TypeRhs, TypeOut>(st, cudf::binary_operator::binop); \
} \
BENCHMARK_REGISTER_F(COMPILED_BINARYOP, binop) \
BENCHMARK_REGISTER_F(COMPILED_BINARYOP, name) \
->Unit(benchmark::kMicrosecond) \
->UseManualTime() \
->Arg(10000) /* 10k */ \
Expand All @@ -70,30 +70,36 @@ using namespace cudf;
using namespace numeric;

// clang-format off
BINARYOP_BENCHMARK_DEFINE(float, int64_t, ADD, int32_t);
BINARYOP_BENCHMARK_DEFINE(duration_s, duration_D, SUB, duration_ms);
BINARYOP_BENCHMARK_DEFINE(float, float, MUL, int64_t);
BINARYOP_BENCHMARK_DEFINE(int64_t, int64_t, DIV, int64_t);
BINARYOP_BENCHMARK_DEFINE(int64_t, int64_t, TRUE_DIV, int64_t);
BINARYOP_BENCHMARK_DEFINE(int64_t, int64_t, FLOOR_DIV, int64_t);
BINARYOP_BENCHMARK_DEFINE(double, double, MOD, double);
BINARYOP_BENCHMARK_DEFINE(int32_t, int64_t, PMOD, double);
BINARYOP_BENCHMARK_DEFINE(int32_t, uint8_t, PYMOD, int64_t);
BINARYOP_BENCHMARK_DEFINE(int64_t, int64_t, POW, double);
BINARYOP_BENCHMARK_DEFINE(float, double, LOG_BASE, double);
BINARYOP_BENCHMARK_DEFINE(float, double, ATAN2, double);
BINARYOP_BENCHMARK_DEFINE(int, int, SHIFT_LEFT, int);
BINARYOP_BENCHMARK_DEFINE(int16_t, int64_t, SHIFT_RIGHT, int);
BINARYOP_BENCHMARK_DEFINE(int64_t, int32_t, SHIFT_RIGHT_UNSIGNED, int64_t);
BINARYOP_BENCHMARK_DEFINE(int64_t, int32_t, BITWISE_AND, int16_t);
BINARYOP_BENCHMARK_DEFINE(int16_t, int32_t, BITWISE_OR, int64_t);
BINARYOP_BENCHMARK_DEFINE(int16_t, int64_t, BITWISE_XOR, int32_t);
BINARYOP_BENCHMARK_DEFINE(double, int8_t, LOGICAL_AND, bool);
BINARYOP_BENCHMARK_DEFINE(int16_t, int64_t, LOGICAL_OR, bool);
BINARYOP_BENCHMARK_DEFINE(duration_ms, duration_ns, EQUAL, bool);
BINARYOP_BENCHMARK_DEFINE(decimal32, decimal32, NOT_EQUAL, bool);
BINARYOP_BENCHMARK_DEFINE(timestamp_s, timestamp_s, LESS, bool);
BINARYOP_BENCHMARK_DEFINE(timestamp_ms, timestamp_s, GREATER, bool);
BINARYOP_BENCHMARK_DEFINE(duration_ms, duration_ns, NULL_EQUALS, bool);
BINARYOP_BENCHMARK_DEFINE(decimal32, decimal32, NULL_MAX, decimal32);
BINARYOP_BENCHMARK_DEFINE(timestamp_D, timestamp_s, NULL_MIN, timestamp_s);
BINARYOP_BENCHMARK_DEFINE(ADD_1, float, float, ADD, float);
BINARYOP_BENCHMARK_DEFINE(ADD_2, timestamp_s, duration_s, ADD, timestamp_s);
BINARYOP_BENCHMARK_DEFINE(SUB_1, duration_s, duration_D, SUB, duration_ms);
BINARYOP_BENCHMARK_DEFINE(SUB_2, int64_t, int64_t, SUB, int64_t);
BINARYOP_BENCHMARK_DEFINE(MUL_1, float, float, MUL, int64_t);
BINARYOP_BENCHMARK_DEFINE(MUL_2, duration_s, int64_t, MUL, duration_s);
BINARYOP_BENCHMARK_DEFINE(DIV_1, int64_t, int64_t, DIV, int64_t);
BINARYOP_BENCHMARK_DEFINE(DIV_2, duration_ms, int32_t, DIV, duration_ms);
BINARYOP_BENCHMARK_DEFINE(TRUE_DIV, int64_t, int64_t, TRUE_DIV, int64_t);
BINARYOP_BENCHMARK_DEFINE(FLOOR_DIV, int64_t, int64_t, FLOOR_DIV, int64_t);
BINARYOP_BENCHMARK_DEFINE(MOD_1, double, double, MOD, double);
BINARYOP_BENCHMARK_DEFINE(MOD_2, duration_ms, int64_t, MOD, duration_ms);
BINARYOP_BENCHMARK_DEFINE(PMOD, int32_t, int64_t, PMOD, double);
BINARYOP_BENCHMARK_DEFINE(PYMOD, int32_t, uint8_t, PYMOD, int64_t);
BINARYOP_BENCHMARK_DEFINE(POW, int64_t, int64_t, POW, double);
BINARYOP_BENCHMARK_DEFINE(LOG_BASE, float, double, LOG_BASE, double);
BINARYOP_BENCHMARK_DEFINE(ATAN2, float, double, ATAN2, double);
BINARYOP_BENCHMARK_DEFINE(SHIFT_LEFT, int, int, SHIFT_LEFT, int);
BINARYOP_BENCHMARK_DEFINE(SHIFT_RIGHT, int16_t, int64_t, SHIFT_RIGHT, int);
BINARYOP_BENCHMARK_DEFINE(USHIFT_RIGHT, int64_t, int32_t, SHIFT_RIGHT_UNSIGNED, int64_t);
BINARYOP_BENCHMARK_DEFINE(BITWISE_AND, int64_t, int32_t, BITWISE_AND, int16_t);
BINARYOP_BENCHMARK_DEFINE(BITWISE_OR, int16_t, int32_t, BITWISE_OR, int64_t);
BINARYOP_BENCHMARK_DEFINE(BITWISE_XOR, int16_t, int64_t, BITWISE_XOR, int32_t);
BINARYOP_BENCHMARK_DEFINE(LOGICAL_AND, double, int8_t, LOGICAL_AND, bool);
BINARYOP_BENCHMARK_DEFINE(LOGICAL_OR, int16_t, int64_t, LOGICAL_OR, bool);
BINARYOP_BENCHMARK_DEFINE(EQUAL_1, int32_t, int64_t, EQUAL, bool);
BINARYOP_BENCHMARK_DEFINE(EQUAL_2, duration_ms, duration_ns, EQUAL, bool);
BINARYOP_BENCHMARK_DEFINE(NOT_EQUAL, decimal32, decimal32, NOT_EQUAL, bool);
BINARYOP_BENCHMARK_DEFINE(LESS, timestamp_s, timestamp_s, LESS, bool);
BINARYOP_BENCHMARK_DEFINE(GREATER, timestamp_ms, timestamp_s, GREATER, bool);
BINARYOP_BENCHMARK_DEFINE(NULL_EQUALS, duration_ms, duration_ns, NULL_EQUALS, bool);
BINARYOP_BENCHMARK_DEFINE(NULL_MAX, decimal32, decimal32, NULL_MAX, decimal32);
BINARYOP_BENCHMARK_DEFINE(NULL_MIN, timestamp_D, timestamp_s, NULL_MIN, timestamp_s);
14 changes: 6 additions & 8 deletions cpp/include/cudf/utilities/type_dispatcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ template <typename T1>
struct double_type_dispatcher_second_type {
#pragma nv_exec_check_disable
template <typename T2, typename F, typename... Ts>
CUDA_HOST_DEVICE_CALLABLE decltype(auto) operator()(F&& f, Ts&&... args) const
CUDF_HDFI decltype(auto) operator()(F&& f, Ts&&... args) const
{
return f.template operator()<T1, T2>(std::forward<Ts>(args)...);
}
Expand All @@ -541,9 +541,7 @@ template <template <cudf::type_id> typename IdTypeMap>
struct double_type_dispatcher_first_type {
#pragma nv_exec_check_disable
template <typename T1, typename F, typename... Ts>
CUDA_HOST_DEVICE_CALLABLE decltype(auto) operator()(cudf::data_type type2,
F&& f,
Ts&&... args) const
CUDF_HDFI decltype(auto) operator()(cudf::data_type type2, F&& f, Ts&&... args) const
{
return type_dispatcher<IdTypeMap>(type2,
detail::double_type_dispatcher_second_type<T1>{},
Expand All @@ -568,10 +566,10 @@ struct double_type_dispatcher_first_type {
*/
#pragma nv_exec_check_disable
template <template <cudf::type_id> typename IdTypeMap = id_to_type_impl, typename F, typename... Ts>
CUDA_HOST_DEVICE_CALLABLE constexpr decltype(auto) double_type_dispatcher(cudf::data_type type1,
cudf::data_type type2,
F&& f,
Ts&&... args)
CUDF_HDFI constexpr decltype(auto) double_type_dispatcher(cudf::data_type type1,
cudf::data_type type2,
F&& f,
Ts&&... args)
{
return type_dispatcher<IdTypeMap>(type1,
detail::double_type_dispatcher_first_type<IdTypeMap>{},
Expand Down
26 changes: 26 additions & 0 deletions cpp/src/binaryop/compiled/NullEquals.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "binary_ops.cuh"

namespace cudf::binops::compiled {
template void apply_binary_op<ops::NullEquals>(mutable_column_device_view&,
column_device_view const&,
column_device_view const&,
bool is_lhs_scalar,
bool is_rhs_scalar,
rmm::cuda_stream_view);
} // namespace cudf::binops::compiled
2 changes: 1 addition & 1 deletion cpp/src/binaryop/compiled/binary_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,6 @@ case binary_operator::PYMOD: apply_binary_op<ops::PyMod>(out, lhs
case binary_operator::POW: apply_binary_op<ops::Pow>(out, lhs, rhs, is_lhs_scalar, is_rhs_scalar, stream); break;
case binary_operator::EQUAL:
case binary_operator::NOT_EQUAL:
case binary_operator::NULL_EQUALS:
if(out.type().id() != type_id::BOOL8) CUDF_FAIL("Output type of Comparison operator should be bool type");
dispatch_equality_op(out, lhs, rhs, is_lhs_scalar, is_rhs_scalar, op, stream); break;
case binary_operator::LESS: apply_binary_op<ops::Less>(out, lhs, rhs, is_lhs_scalar, is_rhs_scalar, stream); break;
Expand All @@ -337,6 +336,7 @@ case binary_operator::SHIFT_RIGHT_UNSIGNED: apply_binary_op<ops::ShiftRightUnsig
case binary_operator::LOG_BASE: apply_binary_op<ops::LogBase>(out, lhs, rhs, is_lhs_scalar, is_rhs_scalar, stream); break;
case binary_operator::ATAN2: apply_binary_op<ops::ATan2>(out, lhs, rhs, is_lhs_scalar, is_rhs_scalar, stream); break;
case binary_operator::PMOD: apply_binary_op<ops::PMod>(out, lhs, rhs, is_lhs_scalar, is_rhs_scalar, stream); break;
case binary_operator::NULL_EQUALS: apply_binary_op<ops::NullEquals>(out, lhs, rhs, is_lhs_scalar, is_rhs_scalar, stream); break;
case binary_operator::NULL_MAX: apply_binary_op<ops::NullMax>(out, lhs, rhs, is_lhs_scalar, is_rhs_scalar, stream); break;
case binary_operator::NULL_MIN: apply_binary_op<ops::NullMin>(out, lhs, rhs, is_lhs_scalar, is_rhs_scalar, stream); break;
default:;
Expand Down
63 changes: 44 additions & 19 deletions cpp/src/binaryop/compiled/binary_ops.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -177,35 +177,51 @@ struct ops2_wrapper {
};

/**
* @brief Functor which does single, and double type dispatcher in device code
* @brief Functor which does single type dispatcher in device code
*
* single type dispatcher for lhs and rhs with common types.
*
* @tparam BinaryOperator binary operator functor
*/
template <class BinaryOperator>
struct binary_op_device_dispatcher {
data_type common_data_type;
mutable_column_device_view out;
column_device_view lhs;
column_device_view rhs;
bool is_lhs_scalar;
bool is_rhs_scalar;

__forceinline__ __device__ void operator()(size_type i)
{
type_dispatcher(common_data_type,
ops_wrapper<BinaryOperator>{out, lhs, rhs, is_lhs_scalar, is_rhs_scalar},
i);
}
};

/**
* @brief Functor which does double type dispatcher in device code
*
* double type dispatcher for lhs and rhs without common types.
*
* @tparam BinaryOperator binary operator functor
*/
template <class BinaryOperator>
struct device_type_dispatcher {
struct binary_op_double_device_dispatcher {
mutable_column_device_view out;
column_device_view lhs;
column_device_view rhs;
bool is_lhs_scalar;
bool is_rhs_scalar;
std::optional<data_type> common_data_type;

__device__ void operator()(size_type i)
__forceinline__ __device__ void operator()(size_type i)
{
if (common_data_type) {
type_dispatcher(*common_data_type,
ops_wrapper<BinaryOperator>{out, lhs, rhs, is_lhs_scalar, is_rhs_scalar},
i);
} else {
double_type_dispatcher(
lhs.type(),
rhs.type(),
ops2_wrapper<BinaryOperator>{out, lhs, rhs, is_lhs_scalar, is_rhs_scalar},
i);
}
double_type_dispatcher(
lhs.type(),
rhs.type(),
ops2_wrapper<BinaryOperator>{out, lhs, rhs, is_lhs_scalar, is_rhs_scalar},
i);
}
};

Expand Down Expand Up @@ -263,10 +279,19 @@ void apply_binary_op(mutable_column_device_view& outd,
auto common_dtype = get_common_type(outd.type(), lhsd.type(), rhsd.type());

// Create binop functor instance
auto binop_func = device_type_dispatcher<BinaryOperator>{
outd, lhsd, rhsd, is_lhs_scalar, is_rhs_scalar, common_dtype};
// Execute it on every element
for_each(stream, outd.size(), binop_func);
if (common_dtype) {
// Execute it on every element
for_each(stream,
outd.size(),
binary_op_device_dispatcher<BinaryOperator>{
*common_dtype, outd, lhsd, rhsd, is_lhs_scalar, is_rhs_scalar});
} else {
// Execute it on every element
for_each(stream,
outd.size(),
binary_op_double_device_dispatcher<BinaryOperator>{
outd, lhsd, rhsd, is_lhs_scalar, is_rhs_scalar});
}
}

} // namespace compiled
Expand Down
41 changes: 27 additions & 14 deletions cpp/src/binaryop/compiled/equality_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,32 @@ void dispatch_equality_op(mutable_column_device_view& outd,
auto common_dtype = get_common_type(outd.type(), lhsd.type(), rhsd.type());

// Execute it on every element
for_each(
stream,
outd.size(),
[op, outd, lhsd, rhsd, is_lhs_scalar, is_rhs_scalar, common_dtype] __device__(size_type i) {
// clang-format off
// Similar enabled template types should go together (better performance)
switch (op) {
case binary_operator::EQUAL: device_type_dispatcher<ops::Equal>{outd, lhsd, rhsd, is_lhs_scalar, is_rhs_scalar, common_dtype}(i); break;
case binary_operator::NOT_EQUAL: device_type_dispatcher<ops::NotEqual>{outd, lhsd, rhsd, is_lhs_scalar, is_rhs_scalar, common_dtype}(i); break;
case binary_operator::NULL_EQUALS: device_type_dispatcher<ops::NullEquals>{outd, lhsd, rhsd, is_lhs_scalar, is_rhs_scalar, common_dtype}(i); break;
default:;
}
// clang-format on
});

if (common_dtype) {
if (op == binary_operator::EQUAL) {
for_each(stream,
outd.size(),
binary_op_device_dispatcher<ops::Equal>{
*common_dtype, outd, lhsd, rhsd, is_lhs_scalar, is_rhs_scalar});
} else if (op == binary_operator::NOT_EQUAL) {
for_each(stream,
outd.size(),
binary_op_device_dispatcher<ops::NotEqual>{
*common_dtype, outd, lhsd, rhsd, is_lhs_scalar, is_rhs_scalar});
}
} else {
if (op == binary_operator::EQUAL) {
for_each(stream,
outd.size(),
binary_op_double_device_dispatcher<ops::Equal>{
outd, lhsd, rhsd, is_lhs_scalar, is_rhs_scalar});
} else if (op == binary_operator::NOT_EQUAL) {
for_each(stream,
outd.size(),
binary_op_double_device_dispatcher<ops::NotEqual>{
outd, lhsd, rhsd, is_lhs_scalar, is_rhs_scalar});
}
}
}

} // namespace cudf::binops::compiled

0 comments on commit 69e6dbb

Please sign in to comment.